深度學習(一)之MNIST數據集分類

任務目標

對MNIST手寫數字數據集進行訓練和評估,最終使得模型能夠在測試集上達到\(98\%\)的正確率。(最終本文達到了\(99.36\%\)

使用的庫的版本:

  1. python:3.8.12
  2. pytorch:1.5.1

代碼地址GitHub:https://github.com/xiaohuiduan/deeplearning-study/tree/main/手寫數字識別

數據集介紹

MNIST數字數據集來自MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges

在torchvision中自帶了關於MNIST的數據集。如果直接使用自帶的數據集,能方便不少。關於具體使用,可參考:PyTorch初探MNIST數據集 - 知乎 (zhihu.com)

在Lecun的提供的MNIST數據集,有如下4個文件(images文件和labels文件):

training set包含了60000張手寫數字圖片,test set包含了10000張圖片。在images文件和labels文件中,數據是使用二進制進行保存的。

圖像文件的二進制儲存格式如下(參考python處理MNIST數據集 - 簡書 (jianshu.com)):

  • 第1-4個byte(字節,1byte=8bit),即前32bit存的是文件的magic number,對應的十進制大小是2051;

  • 第5-8個byte存的是number of images,即圖像數量60000;

  • 第9-12個byte存的是每張圖片行數/高度,即28;

  • 第13-16個byte存的是每張圖片的列數/寬度,即28。

  • 從第17個byte開始,每個byte存儲一張圖片中的一個像素點的值。

標籤文件的二進制儲存格式如下(參考python處理MNIST數據集 - 簡書 (jianshu.com)):

  • 第1-4個byte存的是文件的magic number,對應的十進制大小是2049;

  • 第5-8個byte存的是number of items,即label數量60000;

  • 從第9個byte開始,每個byte存一個圖片的label信息,即數字0-9中的一個。

二進制文件的Python處理代碼:

import numpy as np
def read_image(file_path):
    """讀取MNIST圖片

    Args:
        file_path (str): 圖片文件位置

    Returns:
        list: 圖片列表
    """
    with open(file_path,'rb') as f:
        file = f.read()
        img_num = int.from_bytes(file[4:8],byteorder='big') #圖片數量
        img_h = int.from_bytes(file[8:12],byteorder='big') #圖片h
        img_w = int.from_bytes(file[12:16],byteorder='big') #圖片w
        img_data = []
        file = file[16:]
        data_len = img_h*img_w

        for i in range(img_num):
            data = [item/255 for item in file[i*data_len:(i+1)*data_len]]
            img_data.append(np.array(data).reshape(img_h,img_w))

        return img_data

def read_label(file_path):
    with open(file_path,'rb') as f:
        file = f.read()
        label_num = int.from_bytes(file[4:8],byteorder='big') #label的數量
        file = file[8:]
        label_data = []
        for i in range(label_num):
            label_data.append(file[i])
        return label_data


train_img  = read_image("mnist/train/train-images.idx3-ubyte")
train_label = read_label("mnist/train/train-labels.idx1-ubyte")

# test_img = read_image("mnist/test/t10k-images.idx3-ubyte")
# test_label = read_label("mnist/test/t10k-labels.idx1-ubyte")

數據集部分數據如下所示:

數據集劃分

在深度學習中,需要將trainset劃分成訓練集驗證集。最終使用測試集去驗證模型的結果。

訓練集:用來訓練模型參數。

驗證集:驗證模型的狀況和收斂情況。

測試集:驗證模型結果。

形象上來說訓練集就像是學生的課本,學生 根據課本里的內容來掌握知識,驗證集就像是作業,通過作業可以知道 不同學生學習情況、進步的速度快慢,而最終的測試集就像是考試,考的題是平常都沒有見過,考察學生舉一反三的能力。

來源:訓練集(train)驗證集(validation)測試集(test)與交叉驗證法 - 知乎 (zhihu.com)

因此,需要將上文中的train_img,train_label進行劃分,劃分爲訓練集驗證集。這裏使用sklearn中的train_test_split進行劃分,訓練集和測試集的比例爲\(8:2\)

from sklearn.model_selection import train_test_split
train_img,valid_img,train_label,valid_label = train_test_split(train_img,train_label,test_size=0.2,shuffle=True)

網絡結構

根據網絡的權重,Netron生成的網絡結構圖如下,圖中詳細的介紹了每一層的結構參數。

網絡結構的簡潔圖如下所示,網絡一共由3層卷積層(每層卷積分別由Conv2d,BatchNorm2d,MaxPool2d和Dropout構成)和2個全連接層構成。

Pytorch代碼如下:

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet,self).__init__()
        self.conv_1 = nn.Sequential(
            nn.Conv2d(1,32,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2,2),
            nn.Dropout(0.25)
        )
        self.conv_2 = nn.Sequential(
            nn.Conv2d(32,64,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2,2),
            nn.Dropout(0.25),
        )

        self.conv_3 = nn.Sequential(
            nn.Conv2d(64,128,kernel_size=3),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2,2),
            nn.Dropout(0.25),
        )

        self.fc = nn.Sequential(
            nn.Linear(512,128),
            nn.Linear(128,10)
        )

    def forward(self,x): #x (3,28,28)
        x = self.conv_1(x) #x (32,14,14)
        x = self.conv_2(x) #x (64,7,7)
        x = self.conv_3(x) #x (128,4,4)
        x = x.view(x.size(0),-1)
        
        x = self.fc(x)
        return F.log_softmax(x,dim=1)
myNet = MyNet().to(device)

訓練集以及驗證集結果

大概經過300個epoch訓練,驗證集便能夠達到\(99.9\%\)以上的正確率。

訓練集的Loss曲線:

測試集結果

測試集使用訓練400個epoch之後的模型進行預測。其最終預測的正確率爲:\(99.36 \%\)。實際上,大概300個epoch就能夠在測試集達到\(99\%\)以上的正確率。

參考

  1. MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges
  2. MNIST — Torchvision 0.12 documentation (pytorch.org)
  3. python處理MNIST數據集 - 簡書 (jianshu.com)
  4. 訓練集(train)驗證集(validation)測試集(test)與交叉驗證法 - 知乎 (zhihu.com)
  5. sklearn.model_selection.train_test_split — scikit-learn 1.0.2 documentation
  6. Netron
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章