cifar10数据集训练

下载数据集

CIFAR-10数据集总共有6万张32×32像素的彩色图片及其标签,涵盖十个分类:飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船、卡车。

其中5万张用于训练,1万张用于测试。

 

import tensorflow as tf
from tensorflow import keras
from matplotlib import pyplot as plt
import numpy as np
from tensorflow.keras.layers import Conv2D, MaxPool2D, Flatten, Dense,Dropout

# Load CIFAR-10 through the Keras dataset helper, which returns a
# (train, test) pair of (images, labels) tuples.
cifar10 = keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# Rescale pixel values by 1/255. The images are cast to float32 first:
# dividing the integer arrays directly would promote the whole dataset to
# float64 and double its memory footprint without any benefit for training.
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

 

搭建网络结构

# Build the CNN layer by layer. Three convolutional stages (128, 256 and 512
# filters) each apply a 3x3 same-padded ReLU convolution, batch normalisation
# and 2x2 max pooling; only the first two stages are followed by Dropout(0.3).
model = keras.models.Sequential()
for filters, drop_rate in ((128, 0.3), (256, 0.3), (512, None)):
    model.add(Conv2D(filters, (3, 3), activation='relu', padding='same'))
    model.add(keras.layers.BatchNormalization())
    model.add(MaxPool2D((2, 2)))
    if drop_rate is not None:
        model.add(Dropout(drop_rate))

# Classifier head: flatten the feature maps, then a 512-unit L2-regularised
# dense layer wrapped in Dropout(0.5), ending in a 10-way softmax.
model.add(Flatten())
model.add(Dropout(0.5))
model.add(
    Dense(
        512,
        activation='relu',
        kernel_regularizer=keras.regularizers.l2(0.1),
    )
)
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))

 

编译模型

# Configure training. `learning_rate` is the supported keyword for Adam's
# step size; the legacy `lr` alias is deprecated and rejected by newer Keras
# releases. The sparse categorical cross-entropy loss expects integer class
# labels, and from_logits=False matches the softmax already applied by the
# model's final Dense layer.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.0001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['accuracy'],
)

 

训练模型

# Train for 100 epochs with mini-batches of 16 samples, validating on the
# test split after every epoch. The returned History object keeps the
# per-epoch metrics.
history = model.fit(
    x=x_train,
    y=y_train,
    batch_size=16,
    epochs=100,
    verbose=1,
    validation_data=(x_test, y_test),
    validation_freq=1,
)

 

可视化acc/loss曲线

# Show the accuracy and loss curves for the training and test sets.
# SimHei is selected as the sans-serif font so the Chinese labels render.
plt.rcParams['font.sans-serif'] = ['SimHei']

# Pull the per-epoch series recorded during training.
metrics = history.history
acc, val_acc = metrics['accuracy'], metrics['val_accuracy']
loss, val_loss = metrics['loss'], metrics['val_loss']

# One entry per panel: (subplot index, title, [(series, legend label), ...]).
panels = [
    (1, 'Acc曲线', [(acc, '训练Acc'), (val_acc, '测试Acc')]),
    (2, 'Loss曲线', [(loss, '训练Loss'), (val_loss, '测试Loss')]),
]
for index, title, curves in panels:
    plt.subplot(1, 2, index)
    for values, label in curves:
        plt.plot(values, label=label)
    plt.title(title)
    plt.legend()
plt.show()