基於pytorch實現手寫數字識別(附python代碼)

/1加載圖片:加載數據集,沒有的話會自動下載,數據分佈在0附近,並打散。

訓練集:測試集=6k:1k。

utils.py文件:plot_image()繪製loss下降曲線; plot_curve()顯示圖片通過plot_image()可視化結果。minst_train.py文件:讀取Minst數據集

/2 加載模型:三層線性模型,前兩層用ReLU函數,batch_size=512,一張圖片28*28,Normalize將數據均勻分佈。

/3 訓練:學習率0.01,momentum = 0.9,loss定義,梯度清零、計算、更新,每10次顯示loss,可以看到loss下降:

/4 測試

計算正確率並顯示梯度下降:

遇到的問題:pytorch中優化器獲得的是空參數表

ValueError:optimizer got an empty parameter list

解決:初始函數定義未正確,兩個下劃線

def __init__(self):

        super(Net, self).__init__()

win10+anaconda3+python3.7,安裝tensorflow、pytorch、opencv、CUDA10.2

mnist_train.py

# -*- coding: utf-8 -*-
"""
Created on Tue Jan 14 15:10:20 2020

@author: ZM
"""
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim

import torchvision
from matplotlib import pyplot as plt

from utils import plot_image, plot_curve, one_hot

batch_size=512
#step1:load dataset
#加載數據集,沒有的話會自動下載,數據分佈在0附近,並打散
train_loader=torch.utils.data.DataLoader(
   torchvision.datasets.MNIST('mnist_data',train=True,download=True,
                               transform=torchvision.transforms.Compose([
                                       torchvision.transforms.ToTensor(),
                                       torchvision.transforms.Normalize(
                                               (0.1307,),(0.3081,))
                                       ])),
    batch_size=batch_size,shuffle=True)
                                       
test_loader=torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/',train=False,download=True,
                               transform=torchvision.transforms.Compose([
                                       torchvision.transforms.ToTensor(),
                                       torchvision.transforms.Normalize(
                                               (0.1307,),(0.3081,))
                                       ])),
    batch_size=batch_size,shuffle=False)
                                       
#顯示:batch_size=512,一張圖片28*28,Normalize將數據均勻
x, y = next(iter(train_loader))
print(x.shape,y.shape,x.min(),x.max())
plot_image(x, y, 'image sample')

#建立模型
class Net(nn.Module):
    
    def __init__(self):
        super(Net, self).__init__()
        
        #wx+b
        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64,10)
        
    def forward(self, x):
        #x:[b,1,28,28]
        #h1=relu(w1x+b1)
        x = F.relu(self.fc1(x))
        #h2=relu(h1w2+b2)
        x = F.relu(self.fc2(x))
        #h3=h2w3+b3
        x = self.fc3(x)
        
        return x
#        return F.log_softmax(x, dim=1)
#訓練    
net = Net()#初始化
#返回[w1,b1,w2,b2,w3,b3]
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum = 0.9)
train_loss = []

for epoch in range(3):
    for batch_idx, (x,y) in enumerate(train_loader):
        
#        x[b,1,28,28] y:[512]
#        print(x.shape,y.shape)
#        break
#        x, y = Variable(x), Variable(y)
        #[b,1,28,28]=>[b,784]實際圖片4維打平爲二維
    
        x = x.view(x.size(0), 28*28)
        #[b,10]
        out = net(x)
        #[b,10]
        y_onehot = one_hot(y)
        #loss=mse(out,y_onehot)
        loss = F.mse_loss(out, y_onehot)
        
        optimizer.zero_grad()
        loss.backward()
        #w'=w-li*grad
        optimizer.step()
        
#測試
        train_loss.append(loss.item())
        if batch_idx % 10==0:
            print(epoch, batch_idx, loss.item())
plot_curve(train_loss)
#達到較好的[w1,b1,w2,b2,w3,b3]
            
total_correct=0
for x,y in test_loader:
    x = x.view(x.size(0),28*28)  
    #out:[b,10] => pred:[b]     
    out = net(x)
     
    pred = out.argmax(dim = 1)
    correct = pred.eq(y).sum().float().item()
    total_correct += correct
     
total_num = len(test_loader.dataset)
acc = total_correct / total_num
print('test acc:', acc)

x,y = next(iter(test_loader))
out = net(x.view(x.size(0),28*28))
pred = out.argmax(dim = 1)
plot_image(x, pred, 'test')

utils.py 

# -*- coding: utf-8 -*-
"""
Created on Tue Jan 14 16:37:46 2020

@author: ZM
"""

import torch
from matplotlib import pyplot as plt

def plot_curve(data):
    fig = plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()
    
def plot_image(img, label, name):
    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i+1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()

def one_hot(label, depth=10):
    out = torch.zeros(label.size(0), depth)
    idx = torch.LongTensor(label).view(-1,1)
    out.scatter_(dim=1, index=idx, value=1)
    return out

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章