深度學習基礎 - MNIST實驗(tensorflow+CNN)

深度學習基礎 - MNIST實驗(Tensorflow-CNN)


本文的完整代碼託管在我的Github PnYuan - Practice-of-Machine-Learning - MNIST_tensorflow_demo,歡迎交流。

1.任務背景

這裏,我們擬通過搭建卷積神經網絡(CNN)來完成MNIST手寫數字識別任務,關於MNIST任務的相關內容可參考前文深度學習基礎 - MNIST實驗(tensorflow+Softmax)深度學習基礎 - MNIST實驗(tensorflow+MLP)

2.實驗過程

實驗參考代碼:python + tensorflow: cnn_demo.py & cnn_demo_self_test.py

實驗分三步進行:

  1. 參考LeNet-5,搭建適用於該任務的CNN模型,開發實現基於tensorflow;
  2. 加載MNIST數據集,配置超參數,進行訓練與測試,分析效果;
  3. 加載自制手寫圖片,採用訓練好的CNN進行識別,分析效果;

2.1.CNN建模

LeNet-5是Y.LeCun等人早期所設計的一種CNN,是經典的神經網絡架構之一,如下圖所示:(參考原文獻)

lenet-5_graph

本實驗採用python-tensorflow實現LeNet-5,其建模代碼樣例如下:

'''construction of leNet-5 model'''
def lenet_5_forward_propagation(X):
    """
    @note: construction of leNet-5 forward computation graph:
        CONV1 -> MAXPOOL1 -> CONV2 -> MAXPOOL2 -> FC3 -> FC4 -> SOFTMAX

    @param X: input dataset placeholder, of shape (number of examples (m), input size)

    @return: A_l, the output of the softmax layer, of shape (number of examples, output size)
    """

    # reshape imput as [number of examples (m), weight, height, channel]
    X_ = tf.reshape(X, [-1, 28, 28, 1])  # num_channel = 1 (gray image)

    ### CONV1 (f = 5*5*1, n_f = 6, s = 1, p = 'same')
    W_conv1 = weight_variable(shape = [5, 5, 1, 6])
    b_conv1 = bias_variable(shape = [6])
    # shape of A_conv1 ~ [m,28,28,6]
    A_conv1 = tf.nn.relu(tf.nn.conv2d(X_, W_conv1, strides = [1, 1, 1, 1], padding = 'SAME') + b_conv1)

    ### MAXPOOL1 (f = 2*2*1, s = 2, p = 'same')
    # shape of A_pool1 ~ [m,14,14,6]
    A_pool1 = tf.nn.max_pool(A_conv1, ksize = [1, 2, 2, 1], strides=[1, 2, 2, 1], padding = 'SAME')

    ### CONV2 (f = 5*5*1, n_f = 16, s = 1, p = 'same')
    W_conv2 = weight_variable(shape = [5, 5, 6, 16])
    b_conv2 = bias_variable(shape = [16])    
    # shape of A_conv2 ~ [m,10,10,16]
    A_conv2 = tf.nn.relu(tf.nn.conv2d(A_pool1, W_conv2, strides = [1, 1, 1, 1], padding = 'VALID') + b_conv2)    

    ### MAXPOOL2 (f = 2*2*1, s = 2, p = 'same')  
    # shape of A_pool2~ [m,5,5,16]
    A_pool2 = tf.nn.max_pool(A_conv2, ksize = [1, 2, 2, 1], strides=[1, 2, 2, 1], padding = 'SAME')

    ### FC3 (n = 120)
    # flatten the volumn to vector
    A_pool2_flat = tf.reshape(A_pool2, [-1, 5*5*16])

    W_fc3 = weight_variable([5*5*16, 120])
    b_fc3 = bias_variable([120])
    # shape of A_fc3 ~ [m,120]
    A_fc3 = tf.nn.relu(tf.matmul(A_pool2_flat, W_fc3) + b_fc3)

    ### FC4 (n = 84)
    W_fc4 = weight_variable([120, 84])
    b_fc4 = bias_variable([84])
    # shape of A_fc4 ~ [m, 84]
    A_fc4 = tf.nn.relu(tf.matmul(A_fc3, W_fc4) + b_fc4)

    # Softmax (n = 10)
    W_l = weight_variable([84, 10])
    b_l = bias_variable([10])
    # shape of A_l ~ [m,10]
    A_l=tf.nn.softmax(tf.matmul(A_fc4, W_l) + b_l)

    return A_l

2.2.訓練與測試

設置優化策略及相關超參數(如learning_ratenum_epochsmini-batch size等),進行訓練,經過一段時間的訓練,得出的accuracy結果如下:

Train Accuracy: 0.9920
Valid Accuracy: 0.9896
Test Accuracy: 0.9881

同時該訓練期間,指標accuracycost的變化過程如下圖示:

cnn_training_curve

可以看出,此處CNN(LeNet-5)已經取得了不錯的結果(≈99%的測試準確率)。而通過觀察訓練曲線變化趨勢,猜測隨着迭代的繼續,模型效果還可繼續提升。

2.3.實測

接下來驗證該CNN模型在生活場景下的泛化效果,筆者此處在實驗室即興寫了若干待識別數字,示意如下:

cnn_training_curve

採用之前所訓練的CNN,得出預測結果示意如下:

cnn_training_curve

結果中出現了一些識別錯誤,初步猜測是由數據分佈的差異所引起。可以考慮在圖像訓練和測試時,先採用更多的預處理手段(如灰度歸一化、對比度增強、閾值分割…),從而使分佈接近。降低模型遷移難度。

3.實驗小結

本文采用CNN模型進行mnist手寫數字識別任務,取得了很好的效果(99%的測試準確率)。同時採用訓練好的模型識別了實際場景中的數字,體現了一定的識別效果。

4.參考資料

官方參考:

CNN模型:

開發輔助:

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章