有了之前搭建模型並訓練的基礎,就可以來個實際的例子了,新手必備MNIST登場。
實現一個手寫數字識別非常簡單,數據集是準備好的,一行代碼下載即可,然後按照之前的方式搭建模型就能訓練出一個識別手寫數字效果還不錯的模型了。
Step1:加載tensorflow並下載數據集
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
mnist = tf.keras.datasets.mnist # 下載
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0 # 歸一化處理
Step2:搭建模型並設置訓練流程
model = tf.keras.Sequential()
model.add(layers.Flatten(input_shape=(28, 28)))
model.add(layers.Dense(128, activation='relu'))
model.add(layers.Dropout(0.2))
model.add(layers.Dense(10, activation='softmax'))
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
Step3:啓動訓練並驗證模型
model.fit(x_train, y_train, epochs=10)
model.evaluate(x_test, y_test, verbose=2)
可以看到使用全連接層的模型acc也能達到98%。
現在結合前面的知識,修改一下模型。
Step4:修改模型,用卷積層