A note before we start: I am 「虐貓人薛定諤i」, a post-00s developer who isn't content with the status quo and still has dreams and ambitions.
This blog records and shares what I have learned; follow it to get new posts as soon as they go up.
Stay true to the original aspiration.
The dataset
The dataset is MNIST, probably the most classic benchmark there is: handwritten digit recognition on it is often called the "Hello World" of machine learning.
Design
The network uses two convolutional layers followed by two fully connected blocks, with ReLU as the activation function.
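The flattened feature size the first fully connected layer receives (16 * 5 * 5 in the code below) can be checked with the standard convolution output formula; a quick sketch in plain Python, tracking side lengths only:

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a square conv: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28                       # MNIST input is 28x28
size = conv_out(size, 3, 1, 2)  # conv1: kernel 3, stride 1, padding 2 -> 30
size = size // 2                # 2x2 max pool -> 15
size = conv_out(size, 5)        # conv2: kernel 5 -> 11
size = size // 2                # 2x2 max pool -> 5
print(16 * size * size)         # 16 channels * 5 * 5 = 400
```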
The whole pipeline has three parts:
1. Loading and preprocessing the data
2. Defining the network
3. Training and testing the model
Code
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets
from torchvision import transforms

batch_size = 64
EPOCHS = 5
LR = 0.001

# Training and test splits of the handwritten-digit dataset
train_dataset = datasets.MNIST(root='./res/data',
                               transform=transforms.ToTensor(),
                               train=True,
                               download=True)
test_dataset = datasets.MNIST(root='./res/data',
                              transform=transforms.ToTensor(),
                              train=False,  # without this the training split is reused for testing
                              download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)  # no need to shuffle for evaluation


# Network definition
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2), nn.ReLU(),
                                   nn.MaxPool2d(2, 2))
        self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
                                   nn.MaxPool2d(2, 2))
        self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                                 nn.BatchNorm1d(120), nn.ReLU())
        self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.BatchNorm1d(84),
                                 nn.ReLU(), nn.Linear(84, 10))

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 16 * 5 * 5)
        x = self.fc1(x)
        x = self.fc2(x)
        return x


# Training (Variable is deprecated since PyTorch 0.4; tensors are used directly)
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=LR)
for epoch in range(EPOCHS):
    net.train()  # BatchNorm must run in training mode here
    sum_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        sum_loss += loss.item()
        if i % 100 == 99:
            print("[%d, %d] loss is %.03f" %
                  (epoch + 1, i + 1, sum_loss / 100))
            sum_loss = 0.0

    # Evaluate on the test split after every epoch
    net.eval()
    correct = 0
    with torch.no_grad():  # no gradients needed for evaluation
        for images, labels in test_loader:
            outputs = net(images)
            _, pred = torch.max(outputs, 1)
            correct += (pred == labels).sum()
    print("Test acc is {}".format(correct.item() / len(test_dataset)))
Results
(pytorch) PS D:\Code\Python> & C:/Users/Martin/anaconda3/envs/pytorch/python.exe d:/Code/Python/hwdrByPyTorch.py
[1, 100] loss is 0.705
[1, 200] loss is 0.191
[1, 300] loss is 0.129
[1, 400] loss is 0.111
[1, 500] loss is 0.097
[1, 600] loss is 0.084
[1, 700] loss is 0.070
[1, 800] loss is 0.070
[1, 900] loss is 0.071
Test acc is 0.9866166666666667
[2, 100] loss is 0.090
[2, 200] loss is 0.078
[2, 300] loss is 0.061
[2, 400] loss is 0.068
[2, 500] loss is 0.055
[2, 600] loss is 0.046
[2, 700] loss is 0.045
[2, 800] loss is 0.048
[2, 900] loss is 0.051
Test acc is 0.9852166666666666
[3, 100] loss is 0.036
[3, 200] loss is 0.041
[3, 300] loss is 0.037
[3, 400] loss is 0.036
[3, 500] loss is 0.033
[3, 600] loss is 0.035
[3, 700] loss is 0.039
[3, 800] loss is 0.044
[3, 900] loss is 0.039
Test acc is 0.9891
[4, 100] loss is 0.029
[4, 200] loss is 0.021
[4, 300] loss is 0.030
[4, 400] loss is 0.023
[4, 500] loss is 0.028
[4, 600] loss is 0.031
[4, 700] loss is 0.027
[4, 800] loss is 0.035
[4, 900] loss is 0.039
Test acc is 0.9929666666666667
[5, 100] loss is 0.022
[5, 200] loss is 0.025
[5, 300] loss is 0.021
[5, 400] loss is 0.024
[5, 500] loss is 0.025
[5, 600] loss is 0.029
[5, 700] loss is 0.027
[5, 800] loss is 0.032
[5, 900] loss is 0.029
Test acc is 0.9923666666666666
(pytorch) PS D:\Code\Python>
As the log shows, the model reaches over 98% accuracy on the test set.
Takeaways
1. Consider saving the trained model at the end of training so it can be reused later.
2. Add checkpointing to the code so an interrupted run can resume instead of starting over from scratch.
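Both points above can be sketched as follows. Filenames like `checkpoint.pth` are placeholders, and a small `nn.Linear` stands in for the `Net` defined earlier:

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)  # stand-in for Net()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 1. Save after training; the state_dict (weights only) is enough for inference
torch.save(model.state_dict(), "model.pth")

# 2. Checkpoint for resuming: also keep the optimizer state and epoch counter
torch.save({"epoch": 3,
            "model_state": model.state_dict(),
            "optim_state": optimizer.state_dict()}, "checkpoint.pth")

# Resume from the checkpoint instead of training from scratch
ckpt = torch.load("checkpoint.pth")
model.load_state_dict(ckpt["model_state"])
optimizer.load_state_dict(ckpt["optim_state"])
start_epoch = ckpt["epoch"] + 1
print(start_epoch)  # 4: continue the epoch loop from here
```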