MNIST Handwritten Digit Recognition in PyTorch (Neural Network)

在這裏插入圖片描述

#!/usr/bin/env python
# coding: utf-8

# In[1]:


import torch
import torchvision
import matplotlib.pyplot as plt


# In[2]:


# Hyperparameters for the run: mini-batch size, SGD step size, and number of
# passes over the training set.
batchSize = 64
learningRate = 0.1
epochNum = 10

# Fetch MNIST into ./data (downloaded if not already present). The train=True
# split is used for training and the train=False split for validation;
# ToTensor converts every image to a tensor.
trainDataset = torchvision.datasets.MNIST(
    './data',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
valDataset = torchvision.datasets.MNIST(
    './data',
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)

# Wrap both splits in shuffled mini-batch loaders. drop_last=True discards a
# trailing batch that has fewer than batchSize samples.
trainData = torch.utils.data.DataLoader(
    trainDataset,
    batch_size=batchSize,
    shuffle=True,
    drop_last=True,
)
valData = torch.utils.data.DataLoader(
    valDataset,
    batch_size=batchSize,
    shuffle=True,
    drop_last=True,
)


# In[3]:


# Fully connected classifier: 784 inputs -> 1024 -> 300 -> 10 output scores,
# with ReLU between the linear layers.
# There is deliberately no Softmax at the end: the network emits raw scores,
# which is what torch.nn.CrossEntropyLoss expects as input.
net = torch.nn.Sequential(
    torch.nn.Linear(in_features=784, out_features=1024),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=1024, out_features=300),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=300, out_features=10),
)


# In[4]:


# Plain SGD over every parameter of the network, stepping with learningRate.
optimizer = torch.optim.SGD(params=net.parameters(), lr=learningRate)
# Cross-entropy between the network's output scores and the target labels.
criterion = torch.nn.CrossEntropyLoss()


# In[5]:


# Per-step metric history, accumulated across all epochs:
#   losses / acces       - training loss and accuracy (%) of each mini-batch
#   valLosses / valAcces - loss and accuracy (%) of one validation batch
#                          evaluated right after each training step
losses = []
acces = []
valLosses = []
valAcces = []
for epoch in range(epochNum):
    print(epoch)
    for img, lbl in trainData:
        # ---- training step ----
        # Validation below switches the network to eval mode, so training
        # mode has to be restored at the start of every step.
        net.train()
        # Flatten each image into one row. The row count comes from the batch
        # itself rather than the global batchSize, so a batch of any length
        # is handled.
        img = img.reshape((img.shape[0], -1))

        # Errors are intentionally not caught here: a failing step should stop
        # the run instead of being printed and skipped on every iteration.
        out = net(img)
        loss = criterion(out, lbl)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Predicted class is the index of the largest output score.
        result = out.argmax(dim=1)
        acc = (result == lbl).sum().item() / lbl.shape[0] * 100

        losses.append(loss.item())
        acces.append(acc)

        # ---- validation on a single batch ----
        net.eval()
        # No gradients are needed for evaluation; skipping graph construction
        # saves time and memory.
        with torch.no_grad():
            # A fresh iterator over the shuffled loader yields one batch of
            # validation samples for this step.
            valImg, valLbl = next(iter(valData))
            valImg = valImg.reshape((valImg.shape[0], -1))

            valOut = net(valImg)
            valLoss = criterion(valOut, valLbl)

            valResult = valOut.argmax(dim=1)
            valAcc = (valResult == valLbl).sum().item() / valLbl.shape[0] * 100

        valLosses.append(valLoss.item())
        valAcces.append(valAcc)

    # After every epoch, show the full history so far: accuracy first,
    # then loss, each with the training and validation curves overlaid.
    plt.plot(acces, label='train')
    plt.plot(valAcces, label='validation')
    plt.xlabel('step')
    plt.ylabel('accuracy (%)')
    plt.legend()
    plt.show()

    plt.plot(losses, label='train')
    plt.plot(valLosses, label='validation')
    plt.xlabel('step')
    plt.ylabel('loss')
    plt.legend()
    plt.show()


# In[ ]:
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章