Handwritten Digit Recognition with PyTorch

A few words up front: I am 「虐貓人薛定諤i」, a post-2000s kid who is not content with the status quo and still has dreams and goals to chase.
This blog is mainly where I record and share what I have learned; feel free to follow it to get new posts as soon as they go up.
Stay true to the original aspiration, and you will reach the end.


The Dataset

The dataset used here is MNIST, probably the most classic dataset of its kind; handwritten digit recognition on MNIST is often described as the "Hello World" of artificial intelligence.
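Before training, it can be worth a quick sanity check on the raw data. The following sketch simply inspects the size of the training split and the shape of one sample (the ./res/data path is only a convention matching the code section below):

from torchvision import datasets, transforms

mnist = datasets.MNIST(root='./res/data', train=True, download=True,
                       transform=transforms.ToTensor())
print(len(mnist))          # 60000 training images (the test split has 10000)
image, label = mnist[0]
print(image.shape, label)  # torch.Size([1, 28, 28]) plus the digit it shows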

Design Approach

The network uses two convolutional layers followed by two fully connected layers, with ReLU as the activation function (a quick trace of the feature-map sizes through these layers is sketched right after the outline below):

conv1
conv2
fc1
fc2

The whole workflow breaks down into three parts:
1. Data loading and preprocessing
2. Defining the network structure
3. Training the model and testing it
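For reference, here is how the feature-map sizes work out for a 28×28 MNIST image, which is why fc1 takes 16 * 5 * 5 input features. This is only a sketch that reuses the same layer parameters as the Net class in the code section:

import torch
import torch.nn as nn

x = torch.zeros(1, 1, 28, 28)   # one dummy MNIST image (batch of 1)
conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2), nn.ReLU(), nn.MaxPool2d(2, 2))
conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(), nn.MaxPool2d(2, 2))
print(conv1(x).shape)           # [1, 6, 15, 15]: (28 + 2*2 - 3) + 1 = 30, then pooled to 15
print(conv2(conv1(x)).shape)    # [1, 16, 5, 5]:  15 - 5 + 1 = 11, then pooled to 5
# Flattened: 16 * 5 * 5 = 400 features feed the first fully connected layer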

Code

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets
from torchvision import transforms

batch_size = 64
epochs = 5
LR = 0.001
# Load the MNIST training and test sets
train_dataset = datasets.MNIST(root='./res/data',
                               transform=transforms.ToTensor(),
                               train=True,
                               download=True)
test_dataset = datasets.MNIST(root='./res/data',
                              transform=transforms.ToTensor(),
                              train=False,
                              download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)


# Define the network structure
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # 1x28x28 -> 6x30x30 (conv, padding=2) -> 6x15x15 (max pool)
        self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2), nn.ReLU(),
                                   nn.MaxPool2d(2, 2))
        # 6x15x15 -> 16x11x11 (conv) -> 16x5x5 (max pool)
        self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
                                   nn.MaxPool2d(2, 2))
        self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                                 nn.BatchNorm1d(120), nn.ReLU())
        self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.BatchNorm1d(84),
                                 nn.ReLU(), nn.Linear(84, 10))

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size()[0], -1)  # flatten to (batch, 16 * 5 * 5)
        x = self.fc1(x)
        x = self.fc2(x)
        return x


# Train the model
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=LR)
for epoch in range(epochs):
    net.train()  # BatchNorm layers must be back in training mode after the eval pass below
    sum_loss = 0.0
    for i, data in enumerate(train_loader):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        sum_loss += loss.item()
        if i % 100 == 99:
            print("[%d, %d] loss is %.03f" %
                  (epoch + 1, i + 1, sum_loss / 100))
            sum_loss = 0.0
    # Evaluate on the test set
    net.eval()
    correct = 0
    with torch.no_grad():
        for data_test in test_loader:
            images, labels = data_test
            output_test = net(images)
            _, pred = torch.max(output_test, 1)
            correct += (pred == labels).sum()
    print("Test acc is {}".format(correct.item() / len(test_dataset)))


Results

(pytorch) PS D:\Code\Python> & C:/Users/Martin/anaconda3/envs/pytorch/python.exe d:/Code/Python/hwdrByPyTorch.py
[1, 100] loss is 0.705
[1, 200] loss is 0.191
[1, 300] loss is 0.129
[1, 400] loss is 0.111
[1, 500] loss is 0.097
[1, 600] loss is 0.084
[1, 700] loss is 0.070
[1, 800] loss is 0.070
[1, 900] loss is 0.071
Test acc is 0.9866166666666667
[2, 100] loss is 0.090
[2, 200] loss is 0.078
[2, 300] loss is 0.061
[2, 400] loss is 0.068
[2, 500] loss is 0.055
[2, 600] loss is 0.046
[2, 700] loss is 0.045
[2, 800] loss is 0.048
[2, 900] loss is 0.051
Test acc is 0.9852166666666666
[3, 100] loss is 0.036
[3, 200] loss is 0.041
[3, 300] loss is 0.037
[3, 400] loss is 0.036
[3, 500] loss is 0.033
[3, 600] loss is 0.035
[3, 700] loss is 0.039
[3, 800] loss is 0.044
[3, 900] loss is 0.039
Test acc is 0.9891
[4, 100] loss is 0.029
[4, 200] loss is 0.021
[4, 300] loss is 0.030
[4, 400] loss is 0.023
[4, 500] loss is 0.028
[4, 600] loss is 0.031
[4, 700] loss is 0.027
[4, 800] loss is 0.035
[4, 900] loss is 0.039
Test acc is 0.9929666666666667
[5, 100] loss is 0.022
[5, 200] loss is 0.025
[5, 300] loss is 0.021
[5, 400] loss is 0.024
[5, 500] loss is 0.025
[5, 600] loss is 0.029
[5, 700] loss is 0.027
[5, 800] loss is 0.032
[5, 900] loss is 0.029
Test acc is 0.9923666666666666
(pytorch) PS D:\Code\Python> 

As the output shows, the model reaches an accuracy of over 98% on the test set.

Summary

1. Consider saving the model once training is done, so it can be reused later
2. Add checkpointing so that an interrupted run can resume where it left off instead of starting over from scratch (a minimal sketch of both ideas follows)
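A minimal sketch covering both points; the file name mnist_cnn.pth and the checkpoint keys are illustrative choices, not something from the original code:

CKPT_PATH = './res/mnist_cnn.pth'  # hypothetical path

# Save a checkpoint (for example at the end of every epoch) so training can resume later
torch.save({'epoch': epoch,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()}, CKPT_PATH)

# To resume: rebuild net and optimizer as before, then restore their states
checkpoint = torch.load(CKPT_PATH)
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1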
