# PyTorch handwritten digit recognition (手寫數字識別) on MNIST

import torch as t
from matplotlib import pyplot as plt
def plot_curve(data):
    """Draw a line plot of a sequence of values against their step index.

    Args:
        data: sized sequence of numbers (e.g. per-batch training losses);
            point ``k`` is drawn at x = k.
    """
    steps = range(len(data))
    plt.figure()
    plt.plot(steps, data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()
def plot_image(img, label, name, mean=0.1307, std=0.3081):
    """Show up to the first six images of a batch in a 2x3 grid.

    Args:
        img: batch of images indexed as ``img[i][0]`` (first channel of
            sample ``i``), assumed normalized as ``(pixel - mean) / std``;
            the normalization is undone before display.
        label: per-image values; ``label[i].item()`` is shown in each title.
        name: prefix for each subplot title (``"<name>:<label>"``).
        mean: normalization mean to undo (defaults match the loaders' Normalize).
        std: normalization std to undo (defaults match the loaders' Normalize).
    """
    plt.figure()
    # Batches can hold fewer than six samples (e.g. a short final batch);
    # indexing six blindly would raise IndexError.
    count = min(6, len(img))
    for i in range(count):
        plt.subplot(2, 3, i + 1)
        # Undo the Normalize transform so pixels are displayed in their original range.
        plt.imshow(img[i][0] * std + mean, cmap='gray', interpolation='none')
        plt.title("{}:{}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    # Lay out once, after every subplot and title exists. Calling this inside
    # the loop left the last subplot's title out of the final layout.
    plt.tight_layout()
    plt.show()
def one_hot(label, depth=10):
    """Encode integer class labels as one-hot float rows.

    Args:
        label: tensor of class indices with ``label.size(0)`` samples (a
            ``[b]`` or ``[b, 1]`` tensor); values must lie in ``[0, depth)``.
            Any integer dtype and any device is accepted.
        depth: number of classes, i.e. the width of each row.

    Returns:
        Float tensor of shape ``[label.size(0), depth]`` on the same device as
        ``label``: 1 at each row's label position, 0 elsewhere.
    """
    out = t.zeros(label.size(0), depth, device=label.device)
    # scatter_ needs int64 indices shaped [b, 1]. .long() converts any integer
    # dtype; the legacy t.LongTensor(label) constructor rejected tensors that
    # were not CPU int64.
    idx = label.long().view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)
    return out
from torch import nn
from torch.nn import functional as F
from torch import optim

import torchvision
batch_size = 512

# Shared preprocessing for both splits, so train and test are always
# normalized identically. (0.1307, 0.3081) are presumably the MNIST training
# mean/std; plot_image undoes this same normalization for display.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# 1 load dataset
train_loader = t.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=mnist_transform),
    batch_size=batch_size, shuffle=True)

# The test set is not shuffled, so evaluation order is deterministic.
test_loader = t.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=False, download=True,
                               transform=mnist_transform),
    batch_size=batch_size, shuffle=False)

# Peek at one training batch: shapes and the post-normalization value range.
x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, 'image sample')
# Output: torch.Size([512, 1, 28, 28]) torch.Size([512]) tensor(-0.4242) tensor(2.8215)

# [Figure: six sample training images, each titled with its label]

class Net(nn.Module):
    """Three-layer fully connected classifier for 28x28 images.

    Layers: 784 -> 256 -> 64 -> 10, with ReLU after the first two.
    The last layer returns raw scores; no activation is applied.
    """

    def __init__(self):
        super().__init__()
        # Each nn.Linear computes x @ W.T + b.
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return class scores with last dimension 10.

        Accepts already-flat input whose last dimension is 784 (e.g.
        ``[b, 784]``), or image batches such as ``[b, 1, 28, 28]``. Image
        batches are flattened per sample and must hold 784 values per sample.
        """
        if x.size(-1) != 28 * 28:
            # Image-shaped input: flatten each sample to 784 features.
            # Flat input is passed through untouched.
            x = x.flatten(start_dim=1)
        # h1 = relu(x w1 + b1)
        x = F.relu(self.fc1(x))
        # h2 = relu(h1 w2 + b2)
        x = F.relu(self.fc2(x))
        # h3 = h2 w3 + b3
        return self.fc3(x)
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

# One loss value per training batch; plotted once training ends.
train_loss = []

for epoch in range(3):
    for batch_idx, (x, y) in enumerate(train_loader):
        # Flatten images: [b, 1, 28, 28] -> [b, 784]
        # (b is batch_size, except possibly for the last batch).
        x = x.view(-1, 28 * 28)
        out = net(x)                  # [b, 10] class scores
        y_onehot = one_hot(y)         # [b, 10] regression targets
        loss = F.mse_loss(out, y_onehot)

        # Clear stale gradients, backpropagate, then update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss.append(loss.item())

        # Log progress every 10 batches.
        if not batch_idx % 10:
            print(epoch, batch_idx, loss.item())

plot_curve(train_loss)
# Output (epoch, batch_idx, loss):
# 0 0 0.11438305675983429
# 0 10 0.09387794137001038
# 0 20 0.08220846951007843
# 0 30 0.07706790417432785
# 0 40 0.0702405497431755
# 0 50 0.06567282974720001
# 0 60 0.06183426454663277
# 0 70 0.0587509460747242
# 0 80 0.05580664798617363
# 0 90 0.05178285390138626
# 0 100 0.05214530974626541
# 0 110 0.04988808557391167
# 1 0 0.047764163464307785
# 1 10 0.04818789288401604
# 1 20 0.04520576819777489
# 1 30 0.04330906271934509
# 1 40 0.04346104711294174
# 1 50 0.04184712469577789
# 1 60 0.042132407426834106
# 1 70 0.04193287342786789
# 1 80 0.04033951088786125
# 1 90 0.03958696871995926
# 1 100 0.03668265417218208
# 1 110 0.03986677899956703
# 2 0 0.04010584205389023
# 2 10 0.037799276411533356
# 2 20 0.034503210335969925
# 2 30 0.03469271585345268
# 2 40 0.03576362878084183
# 2 50 0.0365145206451416
# 2 60 0.03551790118217468
# 2 70 0.03547125309705734
# 2 80 0.03524373471736908
# 2 90 0.03232691437005997
# 2 100 0.032091040164232254
# 2 110 0.0334344319999218

# [Figure: training-loss curve produced by plot_curve(train_loss)]

# Accuracy on the held-out test set.
total_correct = 0
# Inference only: no_grad() stops autograd from recording a graph for every
# test batch, which would waste memory and time.
with t.no_grad():
    for x, y in test_loader:
        x = x.view(x.size(0), 28 * 28)
        out = net(x)
        # Predicted class = index of the highest score in each row.
        pred = out.argmax(dim=1)
        correct = pred.eq(y).sum().float().item()
        total_correct += correct

total_num = len(test_loader.dataset)
acc = total_correct / total_num
print('test acc:', acc)
# Output: test acc: 0.8902
# Show the model's predictions for the first test batch.
x, y = next(iter(test_loader))
# Inference only; no autograd graph is needed.
with t.no_grad():
    out = net(x.view(x.size(0), 28 * 28))
pred = out.argmax(dim=1)
plot_image(x, pred, 'test')

# [Figure: six test images, each titled with the model's predicted label]

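# (Scraped blog-page footer, not part of the program:)
# 發表評論 — Post a comment
# 所有評論 — All comments
# 還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布. — No comments yet; want to be the first? Type in the comment box above and click Publish.
# 相關文章 — Related articles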