通過本文你可以快速學會使用keras搭建神經網絡,只需40行代碼構建神經網絡實現mnist數據集手寫數字的識別。MNIST數據集(Mixed National Institute of Standards and Technology database)是美國國家標準與技術研究院收集整理的大型手寫數字數據庫,包含60,000個示例的訓練集以及10,000個示例的測試集.,圖片大小爲28*28。完整代碼在我的github,鏈接:https://github.com/JohnLeek/Tensorflow-study,倉庫中mnist.npz就是數據集,day3_mnist_reg.py和day3_mnist_train_ex4.py爲完整代碼,覺得不錯的github給個star吧。
Keras 是一個用 Python 編寫的高級神經網絡 API,它能夠以 TensorFlow, CNTK, 或者 Theano 作爲後端運行。Keras 的開發重點是支持快速的實驗。能夠以最小的時延把你的想法轉換爲實驗結果,是做好研究的關鍵。
查看相關說明,請訪問https://keras-zh.readthedocs.io/。
Keras 兼容的 Python 版本: Python 2.7-3.6。
目前tensorflow已經將keras最爲標準的後端庫,這一特點在tf2.0中尤爲明顯,以下我們要講的keras默認爲tensorflow中的keras模塊。
C:compile,在運行我們的模型前設置優化器,損失函數等
tensorflow提供了數據集,但是需要我們下載,有個問題就是下載速度太慢,這裏我整理好了數據集,放到了我的github,下載好數據集之後,放在C盤,user/.keras/datasets,文件夾下即可,如圖
mnist = tf.keras.datasets.mnist
(x_train,y_train),(x_test,y_test) = mnist.load_data()
x_train,x_test = x_train/255.,x_test/255.
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128,activation = "relu"),
tf.keras.layers.Dense(10,activation = "softmax")
])
這裏我們首先拉着了神經網絡,利用到了Flatten函數,然後搭建了一個128個輸入節點,激活函數爲relu的輸入層,然後因爲我們要實現0~9數字分類,我們輸出層爲10個神經元,採用softmax使輸出符合概率分佈。
model.compile(optimizer = "adam",
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = False),
metrics = ["sparse_categorical_accuracy"])
這裏我們指點了優化器(optimizer),損失函數(loss),神經網絡準確率評估標準(metrics),要注意因爲我們是執行分類任務並且採用了獨熱碼,所以我們交叉熵損失函數。
model.fit(x_train,y_train,batch_size = 32,epochs = 5,validation_data = (x_test,y_test),
validation_freq = 1)
這裏我們指定訓練集x_train,y_train,訓練接一次喂入神經網絡數據集大小爲32,訓練次數爲5次,測試集爲(x_test,y_test),每隔一輪驗證準確率。
model.summary()
控制檯打印出了我們神經網絡模型,參數數量,準確率(98.47%)
1、爲了方便我們在不同的設備上訓練我們已經訓練好的模型,我們引入了斷點續訓的功能,幫組我們更好的優化神經網絡。我們只需要在代碼中添加這兩句代碼即可。
checkpoint_save_path = "./checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path+".index"):
print("-----------------load Data---------------")
model.load_weights(checkpoint_save_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(
filepath = checkpoint_save_path,
save_weights_only = True,
save_best_only = True
)
這裏我們指定了模型保存路徑: "./checkpoint/mnist.ckpt",如果模型存在我們就在原有的基礎上繼續訓練我們的模型,如果沒有我們在訓練模型的時候保存參數,調用ModelCheckpoint函數,指定我們要保存的參數,這裏我保存了權重,和最優結果。
到這裏還沒結束我們需要對model.fit做一定更改,加一個回調函數,如下:
history = model.fit(x_train,y_train,batch_size = 32,epochs = 5,validation_data = (x_test,y_test),
validation_freq = 1,callbacks = [cp_callback])
好了斷點續訓保存模型就完成了,接下來我們保存下神經網絡可訓練參數:
file = open("./weights_variables.txt","w")
for v in model.trainable_variables:
file.write(str(v.name)+"\n")
file.write(str(v.shape)+"\n")
file.write(str(v.numpy())+"\n")
file.close()