簡潔明瞭的tensorflow2.0教程——用keras實現mnist數字識別

通過本文你可以快速學會使用keras搭建神經網絡,只需40行代碼構建神經網絡實現mnist數據集手寫數字的識別。MNIST數據集(Mixed National Institute of Standards and Technology database)是美國國家標準與技術研究院收集整理的大型手寫數字數據庫,包含60,000個示例的訓練集以及10,000個示例的測試集.,圖片大小爲28*28。完整代碼在我的github,鏈接:https://github.com/JohnLeek/Tensorflow-study,倉庫中mnist.npz就是數據集,day3_mnist_reg.py和day3_mnist_train_ex4.py爲完整代碼,覺得不錯的github給個star吧。

一、keras介紹

Keras 是一個用 Python 編寫的高級神經網絡 API,它能夠以 TensorFlowCNTK, 或者 Theano 作爲後端運行。Keras 的開發重點是支持快速的實驗。能夠以最小的時延把你的想法轉換爲實驗結果,是做好研究的關鍵。

如果你在以下情況下需要深度學習庫,請使用 Keras:

  1. 允許簡單而快速的原型設計(由於用戶友好,高度模塊化,可擴展性)。
  2. 同時支持卷積神經網絡和循環神經網絡,以及兩者的組合。
  3. 在 CPU 和 GPU 上無縫運行。

查看相關說明,請訪問https://keras-zh.readthedocs.io/

Keras 兼容的 Python 版本: Python 2.7-3.6

目前tensorflow已經將keras最爲標準的後端庫,這一特點在tf2.0中尤爲明顯,以下我們要講的keras默認爲tensorflow中的keras模塊。

二、用keras搭建神經網絡技巧

我總結了下邊的幾個關鍵字:

D:Data,加載數據集

S:Sequential,搭建我們的神經網絡

C:compile,在運行我們的模型前設置優化器,損失函數等

F:fit,設置神經網絡,傳入訓練集測試集,指定訓練次數

S:summary,運行神經網絡

三、代碼實現

1、首先我們加載數據集(D)

tensorflow提供了數據集,但是需要我們下載,有個問題就是下載速度太慢,這裏我整理好了數據集,放到了我的github,下載好數據集之後,放在C盤,user/.keras/datasets,文件夾下即可,如圖

然後我們開始加載數據集

mnist = tf.keras.datasets.mnist

(x_train,y_train),(x_test,y_test) = mnist.load_data()

然後調整數據集,進行歸一化操作,加快神經網絡收斂速度

x_train,x_test = x_train/255.,x_test/255.

2、搭建我們的神經網絡(S)

model = tf.keras.models.Sequential([

    tf.keras.layers.Flatten(),

    tf.keras.layers.Dense(128,activation = "relu"),

    tf.keras.layers.Dense(10,activation = "softmax")

])

這裏我們首先拉着了神經網絡,利用到了Flatten函數,然後搭建了一個128個輸入節點,激活函數爲relu的輸入層,然後因爲我們要實現0~9數字分類,我們輸出層爲10個神經元,採用softmax使輸出符合概率分佈。

3、設置神經網絡參數(C)

model.compile(optimizer = "adam",

              loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = False),

              metrics = ["sparse_categorical_accuracy"])

這裏我們指點了優化器(optimizer),損失函數(loss),神經網絡準確率評估標準(metrics),要注意因爲我們是執行分類任務並且採用了獨熱碼,所以我們交叉熵損失函數。

4、設置神經網絡數據集相關參數(F)

model.fit(x_train,y_train,batch_size = 32,epochs = 5,validation_data = (x_test,y_test),

                    validation_freq = 1)

這裏我們指定訓練集x_train,y_train,訓練接一次喂入神經網絡數據集大小爲32,訓練次數爲5次,測試集爲(x_test,y_test),每隔一輪驗證準確率。

5、運行我們的模型(S)

model.summary()

6、結果展示

控制檯打印出了我們神經網絡模型,參數數量,準確率(98.47%)

四、進階操作保存我們的模型,加入斷點續訓,保存模型參數

1、爲了方便我們在不同的設備上訓練我們已經訓練好的模型,我們引入了斷點續訓的功能,幫組我們更好的優化神經網絡。我們只需要在代碼中添加這兩句代碼即可。

checkpoint_save_path = "./checkpoint/mnist.ckpt"

if os.path.exists(checkpoint_save_path+".index"):

    print("-----------------load Data---------------")

    model.load_weights(checkpoint_save_path)


cp_callback = tf.keras.callbacks.ModelCheckpoint(

    filepath = checkpoint_save_path,

    save_weights_only = True,

    save_best_only = True

)

這裏我們指定了模型保存路徑: "./checkpoint/mnist.ckpt",如果模型存在我們就在原有的基礎上繼續訓練我們的模型,如果沒有我們在訓練模型的時候保存參數,調用ModelCheckpoint函數,指定我們要保存的參數,這裏我保存了權重,和最優結果。

到這裏還沒結束我們需要對model.fit做一定更改,加一個回調函數,如下:

history = model.fit(x_train,y_train,batch_size = 32,epochs = 5,validation_data = (x_test,y_test),

                    validation_freq = 1,callbacks = [cp_callback])

好了斷點續訓保存模型就完成了,接下來我們保存下神經網絡可訓練參數:

file = open("./weights_variables.txt","w")

for v in model.trainable_variables:

    file.write(str(v.name)+"\n")

    file.write(str(v.shape)+"\n")

    file.write(str(v.numpy())+"\n")

file.close()

2、結果展示

這個就是我們保存好的模型。

這個就是我們神經網絡所有可訓練參數的值。

我們看一下第一次訓練結果。

現在我們再運行下我們的代碼。看看是不是在上一次訓練的基礎上繼續訓練。

可以看到加載了我們保存好的模型,準確率不斷提高

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章