深度學習基礎 - MNIST實驗(Tensorflow-CNN)
本文的完整代碼託管在我的Github PnYuan - Practice-of-Machine-Learning - MNIST_tensorflow_demo,歡迎交流。
1.任務背景
這裏,我們擬通過搭建卷積神經網絡(CNN)來完成MNIST手寫數字識別任務,關於MNIST任務的相關內容可參考前文深度學習基礎 - MNIST實驗(tensorflow+Softmax)或深度學習基礎 - MNIST實驗(tensorflow+MLP)。
2.實驗過程
實驗參考代碼:python + tensorflow: cnn_demo.py & cnn_demo_self_test.py
實驗分三步進行:
- 參考LeNet-5,搭建適用於該任務的CNN模型,開發實現基於tensorflow;
- 加載MNIST數據集,配置超參數,進行訓練與測試,分析效果;
- 加載自制手寫圖片,採用訓練好的CNN進行識別,分析效果;
2.1.CNN建模
LeNet-5是Y.LeCun等人早期所設計的一種CNN,是經典的神經網絡架構之一,如下圖所示:(參考原文獻)
本實驗採用python-tensorflow實現LeNet-5,其建模代碼樣例如下:
'''construction of leNet-5 model'''
def lenet_5_forward_propagation(X):
"""
@note: construction of leNet-5 forward computation graph:
CONV1 -> MAXPOOL1 -> CONV2 -> MAXPOOL2 -> FC3 -> FC4 -> SOFTMAX
@param X: input dataset placeholder, of shape (number of examples (m), input size)
@return: A_l, the output of the softmax layer, of shape (number of examples, output size)
"""
# reshape imput as [number of examples (m), weight, height, channel]
X_ = tf.reshape(X, [-1, 28, 28, 1]) # num_channel = 1 (gray image)
### CONV1 (f = 5*5*1, n_f = 6, s = 1, p = 'same')
W_conv1 = weight_variable(shape = [5, 5, 1, 6])
b_conv1 = bias_variable(shape = [6])
# shape of A_conv1 ~ [m,28,28,6]
A_conv1 = tf.nn.relu(tf.nn.conv2d(X_, W_conv1, strides = [1, 1, 1, 1], padding = 'SAME') + b_conv1)
### MAXPOOL1 (f = 2*2*1, s = 2, p = 'same')
# shape of A_pool1 ~ [m,14,14,6]
A_pool1 = tf.nn.max_pool(A_conv1, ksize = [1, 2, 2, 1], strides=[1, 2, 2, 1], padding = 'SAME')
### CONV2 (f = 5*5*1, n_f = 16, s = 1, p = 'same')
W_conv2 = weight_variable(shape = [5, 5, 6, 16])
b_conv2 = bias_variable(shape = [16])
# shape of A_conv2 ~ [m,10,10,16]
A_conv2 = tf.nn.relu(tf.nn.conv2d(A_pool1, W_conv2, strides = [1, 1, 1, 1], padding = 'VALID') + b_conv2)
### MAXPOOL2 (f = 2*2*1, s = 2, p = 'same')
# shape of A_pool2~ [m,5,5,16]
A_pool2 = tf.nn.max_pool(A_conv2, ksize = [1, 2, 2, 1], strides=[1, 2, 2, 1], padding = 'SAME')
### FC3 (n = 120)
# flatten the volumn to vector
A_pool2_flat = tf.reshape(A_pool2, [-1, 5*5*16])
W_fc3 = weight_variable([5*5*16, 120])
b_fc3 = bias_variable([120])
# shape of A_fc3 ~ [m,120]
A_fc3 = tf.nn.relu(tf.matmul(A_pool2_flat, W_fc3) + b_fc3)
### FC4 (n = 84)
W_fc4 = weight_variable([120, 84])
b_fc4 = bias_variable([84])
# shape of A_fc4 ~ [m, 84]
A_fc4 = tf.nn.relu(tf.matmul(A_fc3, W_fc4) + b_fc4)
# Softmax (n = 10)
W_l = weight_variable([84, 10])
b_l = bias_variable([10])
# shape of A_l ~ [m,10]
A_l=tf.nn.softmax(tf.matmul(A_fc4, W_l) + b_l)
return A_l
2.2.訓練與測試
設置優化策略及相關超參數(如learning_rate
、num_epochs
、mini-batch size
等),進行訓練,經過一段時間的訓練,得出的accuracy
結果如下:
Train Accuracy: 0.9920
Valid Accuracy: 0.9896
Test Accuracy: 0.9881
同時該訓練期間,指標accuracy
和cost
的變化過程如下圖示:
可以看出,此處CNN(LeNet-5)已經取得了不錯的結果(≈99%的測試準確率)。而通過觀察訓練曲線變化趨勢,猜測隨着迭代的繼續,模型效果還可繼續提升。
2.3.實測
接下來驗證該CNN模型在生活場景下的泛化效果,筆者此處在實驗室即興寫了若干待識別數字,示意如下:
採用之前所訓練的CNN,得出預測結果示意如下:
結果中出現了一些識別錯誤,初步猜測是由數據分佈的差異所引起。可以考慮在圖像訓練和測試時,先採用更多的預處理手段(如灰度歸一化、對比度增強、閾值分割…),從而使分佈接近。降低模型遷移難度。
3.實驗小結
本文采用CNN模型進行mnist手寫數字識別任務,取得了很好的效果(99%的測試準確率)。同時採用訓練好的模型識別了實際場景中的數字,體現了一定的識別效果。
4.參考資料
官方參考:
CNN模型:
開發輔助: