Fashion-MNIST Recognition

Example 1:

1. Import dependencies

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Conv2D, MaxPool2D, Flatten, Dense

2. Load the dataset and add a channel dimension

(train_x, train_y), (test_x, test_y) = tf.keras.datasets.fashion_mnist.load_data()
# add a channel axis so the images match Conv2D's expected input: (N, 28, 28) -> (N, 28, 28, 1)
train_x = np.expand_dims(train_x, -1)
test_x = np.expand_dims(test_x, -1)
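
The pixels are fed in as raw 0-255 values here. A common variant, not part of the original, is to scale them to [0, 1] first, which usually speeds up convergence; a minimal sketch:

train_x = train_x.astype('float32') / 255.0  # scale pixel values from [0, 255] to [0, 1]
test_x = test_x.astype('float32') / 255.0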

3. Build the model

model = tf.keras.Sequential()
# 28x28x1 input; 'same' padding preserves the spatial size, each MaxPool2D halves it
model.add(Conv2D(24, kernel_size=5, padding='same', activation='relu',
                 input_shape=(28, 28, 1)))
model.add(MaxPool2D())

model.add(Conv2D(48, kernel_size=5, padding='same', activation='relu'))
model.add(MaxPool2D())

model.add(Conv2D(48, kernel_size=5, padding='same', activation='relu'))
model.add(MaxPool2D())

model.add(Conv2D(48, kernel_size=5, padding='same', activation='relu'))
model.add(MaxPool2D())

# the feature maps are already 1x1 here, so the remaining pools need
# padding='same' to keep MaxPool2D from failing on inputs smaller than its 2x2 window
model.add(Conv2D(64, kernel_size=5, padding='same', activation='relu'))
model.add(MaxPool2D(padding='same'))

model.add(Conv2D(64, kernel_size=5, padding='same', activation='relu'))
model.add(MaxPool2D(padding='same'))

model.add(Conv2D(64, kernel_size=5, padding='same', activation='relu'))
model.add(MaxPool2D(padding='same'))

model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dense(10, activation='softmax'))  # one probability per clothing class
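
Since every MaxPool2D halves the spatial size, the feature maps shrink to 1x1 after the fourth block, which is why the later pools use padding='same'. One line confirms the layer shapes:

model.summary()  # prints each layer's output shape and parameter count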

4. Define the optimizer, loss function, and evaluation metric

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['acc'])
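
sparse_categorical_crossentropy fits here because load_data() returns integer class indices rather than one-hot vectors, which a quick check confirms:

print(train_y[:5], train_y.dtype)  # integer labels in 0-9, printed as e.g. [9 0 0 3 0] uint8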

5. Train the model

history = model.fit(train_x, train_y, epochs=5, validation_data=(test_x, test_y))
print(history.history.keys())  # e.g. dict_keys(['loss', 'acc', 'val_loss', 'val_acc'])
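
After training, the overall test-set performance can be read off directly (a usage sketch):

test_loss, test_acc = model.evaluate(test_x, test_y)
print('test accuracy:', test_acc)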

6. Plot the accuracy curves

plt.plot(history.epoch, history.history.get('acc'), label='acc')
plt.plot(history.epoch, history.history.get('val_acc'), label='val_acc')
plt.legend()
plt.show()

7. Plot the loss curves

plt.plot(history.epoch, history.history.get('loss'), label='loss')
plt.plot(history.epoch, history.history.get('val_loss'), label='val_loss')
plt.legend()
plt.show()
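
To turn the softmax outputs into readable predictions, map the argmax index to the Fashion-MNIST class names (a minimal sketch; class_names follows the dataset's documented label order):

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
pred = model.predict(test_x[:1])        # probabilities over the 10 classes
print(class_names[np.argmax(pred[0])])  # most likely class for the first test image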

Example 2: adding data augmentation

1. Import dependencies

import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import Conv2D, Flatten, Dense, BatchNormalization, Dropout

2. Load the dataset and add a channel dimension

(train_x, train_y), (test_x, test_y) = tf.keras.datasets.fashion_mnist.load_data()
train_x = np.expand_dims(train_x, -1)
test_x = np.expand_dims(test_x, -1)

3. Define the data augmentation

datagen = ImageDataGenerator(
    rotation_range=10,       # random rotations of up to ±10 degrees
    zoom_range=0.10,         # random zoom of up to 10%
    width_shift_range=0.1,   # horizontal shifts of up to 10% of the width
    height_shift_range=0.1)  # vertical shifts of up to 10% of the height
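
To see what the generator actually produces, you can pull one augmented batch and plot it (a sketch, not part of the original):

aug_x, aug_y = next(datagen.flow(train_x, train_y, batch_size=9))
for i in range(9):
    plt.subplot(3, 3, i + 1)
    plt.imshow(aug_x[i].squeeze(), cmap='gray')  # drop the channel axis for plotting
    plt.axis('off')
plt.show()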

4. Define the model

model = tf.keras.Sequential()

# block 1: two 3x3 convs, then a strided 5x5 conv that halves the spatial size
model.add(Conv2D(32, kernel_size=3, activation='relu', input_shape=(28, 28, 1)))
model.add(BatchNormalization())
model.add(Conv2D(32, kernel_size=3, activation='relu'))
model.add(BatchNormalization())
model.add(Conv2D(32, kernel_size=5, strides=2, padding='same', activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.4))

# block 2: the same pattern with 64 filters
model.add(Conv2D(64, kernel_size=3, activation='relu'))
model.add(BatchNormalization())
model.add(Conv2D(64, kernel_size=3, activation='relu'))
model.add(BatchNormalization())
model.add(Conv2D(64, kernel_size=5, strides=2, padding='same', activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.4))

# a 4x4 conv over the remaining 4x4 feature maps collapses them to 1x1
model.add(Conv2D(128, kernel_size=4, activation='relu'))
model.add(BatchNormalization())
model.add(Flatten())
model.add(Dropout(0.4))
model.add(Dense(10, activation='softmax'))
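
Unlike Example 1, this model downsamples with strided convolutions instead of pooling layers: with padding='same', strides=2 simply halves the spatial size. A quick shape check on a dummy tensor illustrates this:

x = tf.random.normal((1, 28, 28, 1))
y = Conv2D(32, kernel_size=5, strides=2, padding='same')(x)
print(y.shape)  # (1, 14, 14, 32): spatial size halved from 28 to 14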

5. Define the optimizer, loss function, and evaluation metric

model.compile(optimizer='adam', 
              loss='sparse_categorical_crossentropy', 
              metrics=['acc'])

6. Fit the generator on the training data

datagen.fit only computes dataset-wide statistics, which are required when options such as featurewise_center or zca_whitening are enabled; with the purely geometric transforms used here it is effectively a no-op and is kept for completeness. The augmentation itself happens on the fly inside datagen.flow during training.

datagen.fit(train_x)

7. Train the model

# datagen.flow yields a freshly augmented batch of 32 images at each step;
# validation still uses the untouched test images
history = model.fit(datagen.flow(train_x, train_y, batch_size=32),
                    epochs=5, validation_data=(test_x, test_y))
print(history.history.keys())

8. Plot the accuracy curves

plt.plot(history.epoch, history.history.get('acc'), label='acc')
plt.plot(history.epoch, history.history.get('val_acc'), label='val_acc')
plt.legend()
plt.show()

9. Plot the loss curves

plt.plot(history.epoch, history.history.get('loss'), label='loss')
plt.plot(history.epoch, history.history.get('val_loss'), label='val_loss')
plt.legend()
plt.show()
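
Once the curves look reasonable, the trained model can be persisted and reloaded for inference (a usage sketch; the filename is an arbitrary example):

model.save('fashion_mnist_cnn.h5')  # saves architecture, weights, and optimizer state
restored = tf.keras.models.load_model('fashion_mnist_cnn.h5')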
