LeNet-5 Implementation

Import dependencies

import tensorflow as tf
from tensorflow.keras.layers import Conv2D, MaxPool2D, Flatten, Dense
import numpy as np
import matplotlib.pyplot as plt
print(tf.__version__)
2.4.0

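Load the dataset (CIFAR-10: 50,000 training and 10,000 test RGB images of size 32×32, with integer labels)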

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()
print(train_images.shape, train_labels.shape)
print(test_images.shape, test_labels.shape)
(50000, 32, 32, 3) (50000, 1)
(10000, 32, 32, 3) (10000, 1)

View some training images

fig, ax = plt.subplots(nrows=4, ncols=5, sharex='all', sharey='all')
ax = ax.flatten()
for i in range(20):
    img = train_images[i]
    ax[i].imshow(img)
# The subplots share their axes, so clearing the ticks on the first one clears them on all 20
ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()

View the label classes

labels = np.unique(train_labels)
print(labels)
[0 1 2 3 4 5 6 7 8 9]
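The labels are the integers 0 to 9, and CIFAR-10 assigns them to ten class names in a fixed order. The sketch below maps labels to names and counts the examples per class. The helper name class_counts is hypothetical. It accepts the (N, 1) column shape that load_data() returns as well as a flat (N,) array.

```python
import numpy as np

# Standard CIFAR-10 class names, indexed by integer label 0-9.
CIFAR10_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']

def class_counts(labels):
    """Return {class_name: count} for an array of integer labels."""
    flat = np.asarray(labels).ravel()          # (N, 1) -> (N,)
    values, counts = np.unique(flat, return_counts=True)
    return {CIFAR10_CLASSES[v]: int(c) for v, c in zip(values, counts)}

# Usage with the arrays loaded above:
# print(class_counts(train_labels))
```

CIFAR-10 is balanced, so on train_labels every class should report 5,000 images.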

Normalize the images (scale pixel values from [0, 255] to [0, 1])

train_images, test_images = train_images / 255.0, test_images / 255.0
print(train_images.shape, test_images.shape)
(50000, 32, 32, 3) (10000, 32, 32, 3)
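Dividing the uint8 arrays by the Python float 255.0 produces float64 arrays. For the 50,000 training images that is about 1.2 GB. Casting to float32 first halves this to roughly 614 MB. The sketch below shows that variant. The function name normalize is a hypothetical helper, not part of the original code.

```python
import numpy as np

def normalize(images):
    """Scale uint8 pixel values in [0, 255] to float32 values in [0, 1]."""
    # Cast first so the division happens in float32 rather than float64.
    return images.astype(np.float32) / np.float32(255.0)

# Possible usage in place of the division above:
# train_images, test_images = normalize(train_images), normalize(test_images)
```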

Build the model

model = tf.keras.models.Sequential([
    Conv2D(filters=6, kernel_size=(5, 5), activation='sigmoid', input_shape=(32, 32, 3)),
    MaxPool2D(pool_size=(2, 2), strides=2),
    Conv2D(filters=16, kernel_size=(5, 5), activation='sigmoid'),
    MaxPool2D(pool_size=(2, 2), strides=2),
    Flatten(),
    Dense(120, activation='sigmoid'),
    Dense(84, activation='sigmoid'),
    Dense(10, activation='softmax')
])
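Neither Conv2D nor MaxPool2D is given padding here, so both use 'valid' padding. A window of size k with stride s turns a side of length n into floor((n - k) / s) + 1. The sketch below traces the 32×32 input through the four layers. The helper conv_out is illustrative. The resulting side lengths and the flattened size of 400 should match the Output Shape column of model.summary().

```python
def conv_out(size, kernel, stride=1):
    """Output side length of a 'valid' (unpadded) convolution or pooling window."""
    return (size - kernel) // stride + 1

side = 32
side = conv_out(side, 5)             # Conv2D 5x5      -> 28
side = conv_out(side, 2, stride=2)   # MaxPool2D 2x2/2 -> 14
side = conv_out(side, 5)             # Conv2D 5x5      -> 10
side = conv_out(side, 2, stride=2)   # MaxPool2D 2x2/2 -> 5
flat = side * side * 16              # Flatten with 16 channels -> 400
print(side, flat)  # 5 400
```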

Specify the optimizer, loss function, and evaluation metric

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['acc'])
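The 'sparse_' variant of the loss accepts integer labels such as train_labels directly. Plain categorical_crossentropy would instead need one-hot vectors. Both compute the mean of -log(probability assigned to the true class). The NumPy sketch below checks that the two formulations agree on a toy batch. It ignores the numerical-stability details of the Keras implementation, and the function names are illustrative.

```python
import numpy as np

def sparse_ce(probs, labels):
    """Mean of -log(p[true class]) using integer labels."""
    labels = np.asarray(labels).ravel()              # accept (N, 1) or (N,)
    picked = probs[np.arange(len(labels)), labels]   # probability of the true class
    return float(np.mean(-np.log(picked)))

def categorical_ce(probs, one_hot):
    """The same loss computed from one-hot targets."""
    return float(np.mean(-np.sum(one_hot * np.log(probs), axis=1)))

probs = np.array([[0.7, 0.2, 0.1],     # softmax outputs for two examples
                  [0.1, 0.3, 0.6]])
labels = np.array([[0], [2]])          # (N, 1) shape, like train_labels
one_hot = np.eye(3)[labels.ravel()]
print(sparse_ce(probs, labels), categorical_ce(probs, one_hot))
```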

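Train the model (the test set is passed as validation_data here, so the val_loss and val_acc values below are test-set figures)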

history = model.fit(train_images, train_labels,
                    epochs=10,
                    validation_data=(test_images, test_labels))
Epoch 1/10
1563/1563 [==============================] - 35s 22ms/step - loss: 2.1636 - acc: 0.1792 - val_loss: 1.8696 - val_acc: 0.3045
Epoch 2/10
1563/1563 [==============================] - 33s 21ms/step - loss: 1.8480 - acc: 0.3158 - val_loss: 1.6941 - val_acc: 0.3849
Epoch 3/10
1563/1563 [==============================] - 34s 22ms/step - loss: 1.6897 - acc: 0.3837 - val_loss: 1.5897 - val_acc: 0.4275
Epoch 4/10
1563/1563 [==============================] - 33s 21ms/step - loss: 1.5677 - acc: 0.4330 - val_loss: 1.4668 - val_acc: 0.4658
Epoch 5/10
1563/1563 [==============================] - 32s 21ms/step - loss: 1.4769 - acc: 0.4680 - val_loss: 1.4604 - val_acc: 0.4709
Epoch 6/10
1563/1563 [==============================] - 33s 21ms/step - loss: 1.4172 - acc: 0.4878 - val_loss: 1.3686 - val_acc: 0.5014
Epoch 7/10
1563/1563 [==============================] - 31s 20ms/step - loss: 1.3699 - acc: 0.5006 - val_loss: 1.3381 - val_acc: 0.5147
Epoch 8/10
1563/1563 [==============================] - 33s 21ms/step - loss: 1.3409 - acc: 0.5160 - val_loss: 1.3198 - val_acc: 0.5230
Epoch 9/10
1563/1563 [==============================] - 33s 21ms/step - loss: 1.2986 - acc: 0.5288 - val_loss: 1.2825 - val_acc: 0.5385
Epoch 10/10
1563/1563 [==============================] - 32s 21ms/step - loss: 1.2629 - acc: 0.5442 - val_loss: 1.2597 - val_acc: 0.5472
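The 1563 in each progress bar is the number of batches per epoch. model.fit was called without batch_size, so Keras used its default of 32. Because 50,000 is not a multiple of 32, the last batch is partial. The helper name below is illustrative.

```python
import math

def steps_per_epoch(num_samples, batch_size=32):
    """Batches per epoch; the final batch may be smaller than batch_size."""
    return math.ceil(num_samples / batch_size)

print(steps_per_epoch(50000))       # 1563 (1562 full batches + one batch of 16)
print(steps_per_epoch(50000, 128))  # 391
```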

View the model structure

model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 28, 28, 6)         456       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 6)         0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 10, 10, 16)        2416      
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 5, 5, 16)          0         
_________________________________________________________________
flatten (Flatten)            (None, 400)               0         
_________________________________________________________________
dense (Dense)                (None, 120)               48120     
_________________________________________________________________
dense_1 (Dense)              (None, 84)                10164     
_________________________________________________________________
dense_2 (Dense)              (None, 10)                850       
=================================================================
Total params: 62,006
Trainable params: 62,006
Non-trainable params: 0
_________________________________________________________________
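The Param # column follows directly from the layer sizes. A Conv2D layer has one kernel × kernel × in_channels weight tensor plus one bias per filter. A Dense layer has a full inputs × units weight matrix plus one bias per unit. The sketch below reproduces each count and the 62,006 total. The helper names are illustrative.

```python
def conv_params(kernel, in_ch, filters):
    # (kernel * kernel * in_ch) weights + 1 bias, per filter
    return (kernel * kernel * in_ch + 1) * filters

def dense_params(inputs, units):
    # inputs weights + 1 bias, per unit
    return (inputs + 1) * units

layers = {
    'conv2d': conv_params(5, 3, 6),       # 456
    'conv2d_1': conv_params(5, 6, 16),    # 2416
    'dense': dense_params(400, 120),      # 48120
    'dense_1': dense_params(120, 84),     # 10164
    'dense_2': dense_params(84, 10),      # 850
}
print(sum(layers.values()))  # 62006
```

The pooling and Flatten layers have no weights, which is why they show 0.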

Plot accuracy over the training epochs

plt.plot(history.epoch, history.history.get('acc'), label='acc')
plt.plot(history.epoch, history.history.get('val_acc'), label='val_acc')
plt.legend()
plt.show()

Plot loss over the training epochs

plt.plot(history.epoch, history.history.get('loss'), label='loss')
plt.plot(history.epoch, history.history.get('val_loss'), label='val_loss')
plt.legend()
plt.show()
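history.history is a dict of per-epoch lists keyed by 'loss', 'acc', 'val_loss' and 'val_acc'. The 'acc' names come from metrics=['acc'] in compile. history.epoch is 0-based, so both plots start at x = 0. The hypothetical helper below reads the same numbers as the curves: the best validation epoch (reported 1-based, matching the log) and the final train/validation accuracy gap. In the run logged above, val_acc is still rising at epoch 10 and is slightly above acc. This suggests the model had not yet started to overfit, and more epochs might still help.

```python
def summarize(history_dict, metric='val_acc'):
    """Return (best epoch, 1-based), the best value, and the final acc - val_acc gap."""
    values = history_dict[metric]
    best = max(range(len(values)), key=lambda i: values[i])
    gap = history_dict['acc'][-1] - history_dict['val_acc'][-1]
    return best + 1, values[best], gap

# Usage after training:
# epoch, best_val_acc, gap = summarize(history.history)
```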

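When reposting, please credit the source. You are welcome to verify the sources cited in this article, and to point out any errors or unclear wording. Leave a comment below or email 2621041184@qq.com.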