手寫數字識別Mnist的Pytorch實現

手寫數字識別Mnist的Pytorch實現

注:該內容爲校內課程實驗,僅供參考,請勿抄襲!
源碼地址:Gray-scale-Hand-Written-Digits-Pytorch

一、引言(Introduction)

  手寫數字識別時經典的圖像分類任務,也是經典的有監督學習任務,經常被用於測試圖像的特徵提取效果、分類器性能度量等方面,本文將通過應用機器學習和深度學習算法實現手寫數字識別。
  圖像分類任務是指給定一張圖像來對其進行分類,常見的圖像分類任務有手寫數字識別、貓狗分類、物品識別等,圖像分類也是計算機視覺基本的分類任務。而對於手寫數字識別任務來說,可以當做圖像分類問題,也可以當做目標檢測與分類。其中圖像分類即輸入整個圖像並預測其類別,例如含有6的數字的圖像我們希望最大化預測爲6的概率;另外也可以視爲目標檢測任務,即提取圖像中的目標並將目標提取出後進行預測,例如OCR對字符進行識別。因爲手寫數字是被預處理後的圖像,且一張圖像中只包含一個數字,因此本文則將手寫數字識別視爲整個圖像的分類。

二、任務分析

2.1 形式化描述
  給定一個圖像數據集,其中圖像記做,是一個寬爲,高爲,通道數爲的圖像,是圖像對應的類標,任務的目標是尋找一個由圖像數據到類別的映射。
2.2 任務分析
  傳統的方法是對圖像進行序列化,即使用一組向量來對圖像進行表示。例如本文處理的手寫數字識別是灰度圖像,即通道數爲1,寬高均爲28像素的圖像,因此可以直接將圖像的每個像素使用0-255整型數進行表示,並形成784維度的向量,然後使用包括SVM(支持向量機)、LR(邏輯迴歸)、DT(決策樹)等機器學習學習多個超平面將假設空間中的樣本正確的分類。另外也可以使用聚類算法,例如KNN、K-means、DBSCAN等算法自動將樣本聚到10個類別上。
  另外由於手寫數字相同類別之間會存在相關性,因此也有基於圖像壓縮方法進行特徵提取工作。通常使用PCA等降維技術將784維度的圖像降維到較低空間,形成潛在的特徵向量,且這些向量每一個維度之間是不相關;其次對壓縮後的特徵向量在使用機器學習算法進行分類,這種方法可以大大提高對重要特徵的學習,忽略噪聲對分類的影響。
  隨着深度學習的發展,基於深度學習神經網絡可以自動地對特徵進行提取以及分類成爲圖像分類的主流方法。常規有直接將圖像對應的矩陣(或張量)進行展開後直接喂入一個前饋神經網絡,或使用卷積神經網絡或膠囊網絡對特徵進行提取,並使用一層前饋神經網絡進行分類。基於深度學習的方法通常可以有效的提升分類的性能和精度。
本文主要進行了簡單的對比實驗,對比方案包括機器學習算法(KNN算法和決策樹算法)以及深度學習算法(神經網絡、CNN),並進行可視化展示。機器學習算法在實驗1和2中有所介紹,因此本節主要介紹CNN網絡:
  CNN爲兩層卷積層以及池化層。第一層卷積層爲32個大小爲33的卷積核,第二層卷積層爲64個22的卷積核,兩個池化層均爲2值最大池化,卷積網絡則爲3136維度的向量,輸出層爲兩層神經網絡,網絡中使用正則化防止過擬合,輸出部分爲softmax。

三、數據描述

  本次實驗使用MNIST數據集進行實驗,其中我們用6000張圖像作爲訓練集,1000張圖像作爲測試集,圖像的示例如圖所示:

在這裏插入圖片描述
  由於數據集已經集成於一些深度學習框架中,因此我們直接使用pytorch的torchvision中的datasets庫下載相應數據集。數據集包括如下幾個文件,如下所示:

train-images-idx3-ubyte 訓練集圖像數據二進制文件
train-labels-idx1-ubyte 訓練集對應類標二進制文件
t10k-images-idx3-ubyte 測試集10K圖像數據二進制文件
t10k-labels-idx1-ubyte 測試集10K對應類標二進制文件

  數據集是二進制文件,因此我們使用Pytorch讀取數據集,並直接轉換爲張量,其次將每張圖像與類標存入JSON數據中,保存爲“{‘img’: img, ‘label’: label}”格式。另外我們使用min_batch方法進行訓練,因此使用Pytorch提供的DataLoader方法自動生成batch。

四、實驗

  實驗中,首先使用Sklearn調用了包括KNN(K近鄰)和DT(決策樹)兩個算法並對訓練集進行訓練,其次在測試集上進行實驗:。其次使用Pytorch實現只有一層隱層的神經網絡以及含有多層卷積核池化層的CNN網絡進行實驗,程序劃分爲基於機器學習的訓練入口函數(ml_main.py),機器學習算法類爲Classifier.py;基於深度學習的訓練入口函數(dl_main.py)以及模型爲Network.py。使用機器學習的算法實驗結果如表所示:

算法 精確度
KNN 97.50%
DT 75.96%
SVM 97.92%

  在使用深度學習訓練時,相關超參數如表2所示:

超參數 取值
Epoch 20
Batch_size 30
Learn_rate 0.01
Hidden_size 196

基於深度學習的實驗結果如表4所示:

算法 精確度
單隱層神經網絡 93.86%
CNN 98.86%

  下圖展示了在CNN模型下訓練和測試過程中的損失與精度的變化曲線,以展示最優的CNN的收斂情況。
在這裏插入圖片描述

  其中橫座標表示統計的次數,訓練集的loss和acc則是每訓練20個batch統計一次,測試集的loss和test則是每1000個batch記錄一次。

五、總結

  通過使用幾個簡單的機器學習和深度學習算法實現了對手寫數字識別數據集MNIST的分類,可以發現機器學習算法中SVM模型表現最優,在深度學習模型中CNN分類效果最優。另外通過對模型訓練過程中的收斂情況可知,當訓練第3輪時模型以及基本達到收斂,因此可知模型的收斂速度和收斂性得以保證。在今後的拓展實驗中,我們還將會對彩色圖像以及場景圖像進行識別,以提升模型的魯棒性。

參考文獻
[1]Simonyan K, Zisserman A. VERY DEEP CONVOLUTIONAL NETWORKS FOR LARGE-SCALE IMAGE RECOGNITION[C]. computer vision and pattern recognition, 2014.
[2]He K, Zhang X, Ren S, et al. Deep Residual Learning for Image Recognition[C]. computer vision and pattern recognition, 2016: 770-778.
[3]張黎;劉爭鳴;唐軍;;基於BP神經網絡的手寫數字識別方法的實現[J];自動化與儀器儀表;2015年06期

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章