PyTorch: Quickly Building the LeNet-5 Handwriting Recognition Network


PyTorch is a very concise neural-network framework that lets you build the machine-learning model you need very quickly.

For PyTorch installation and an introduction, see: https://blog.csdn.net/qq_33302004/article/details/106320649

This article takes the handwritten-digit recognition problem as its starting point and builds the LeNet-5 network, whose structure is shown in the figure below:

[Figure: LeNet-5 network architecture]
Unlike the figure, the input images in my dataset are 28*28 (rather than 32*32), so the first fully connected layer is 16*4*4 → 120: each 5*5 convolution (no padding) shrinks the spatial size by 4 and each 2*2 pooling halves it, giving 28 → 24 → 12 → 8 → 4. The code is as follows:

#coding=utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision                      # provides loaders for common datasets such as ImageNet, CIFAR10 and MNIST
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

# Hyperparameters
batch_size = 100         # batch size
learning_rate = 0.01     # learning rate
momentum = 0.5           # momentum
TRAINING_STEPS = 10      # number of training epochs


# Data preprocessing: convert to a tensor, then normalize to mean 0.5, std 0.5
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# Load the data (downloaded on first run), applying the transform defined above
dataset_train = datasets.MNIST('MNIST/', train=True, download=True, transform=transform)
dataset_test = datasets.MNIST('MNIST/', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(
    dataset = dataset_train,
    batch_size = batch_size,
    shuffle = True              # shuffle the training data every epoch
)
test_loader = torch.utils.data.DataLoader(
    dataset = dataset_test,
    batch_size = batch_size,
    shuffle = False
)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.conv1 = nn.Conv2d(1,6,5)       # layer 1: convolution, 1 input channel, 6 output channels, 5*5 kernel
        self.pool1 = nn.MaxPool2d(2)        # layer 2: 2*2 max pooling
        self.conv2 = nn.Conv2d(6,16,5)      # layer 3: convolution, 6 input channels, 16 output channels, 5*5 kernel
        self.pool2 = nn.MaxPool2d(2)        # layer 4: 2*2 max pooling
        self.fc1 = nn.Linear(16*4*4,120)    # layer 5: fully connected, 16*4*4 input nodes, 120 output nodes
        self.fc2 = nn.Linear(120,84)        # layer 6: fully connected, 120 input nodes, 84 output nodes
        self.fc3 = nn.Linear(84,10)         # layer 7: fully connected, 84 input nodes, 10 output nodes

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = x.view(-1, 16*4*4)              # flatten the feature maps into a vector
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        x = F.log_softmax(x, dim=1)         # log-probabilities, paired with nll_loss below
        return x

def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
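
# Example usage of imshow (my addition; the original post defines the helper
# but never calls it): display the first batch of training digits as a grid.
images, labels = next(iter(train_loader))
imshow(torchvision.utils.make_grid(images))
print('labels:', labels[:10].tolist())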

# Instantiate the network
net = Net()
# Build the optimizer: SGD with momentum (weight_decay, i.e. L2 regularization, is left at its default of 0)
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum)
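
# Quick sanity check (my addition, not in the original post): push one dummy
# 28*28 image through the network to confirm the 16*4*4 flatten size is right.
with torch.no_grad():
    assert net(torch.zeros(1, 1, 28, 28)).shape == (1, 10)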

def do_train(epoch):
    net.train()
    for step, (x, y) in enumerate(train_loader):
        y_ = net(x)
        loss = F.nll_loss(y_, y)
        optimizer.zero_grad()
        loss.backward()
        # update
        optimizer.step()
        if step % 200 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, step * len(x), len(train_loader.dataset),
                100. * step / len(train_loader), loss.item()))

def do_test():
    net.eval()
    test_loss = 0
    correct = 0
    # 測試集
    with torch.no_grad():
        for x, y in test_loader:
            y_ = net(x)
            # sum up the batch loss (reduction='sum', so dividing by the dataset
            # size below yields the true per-sample average)
            test_loss += F.nll_loss(y_, y, reduction='sum').item()
            # get the index of the max log-probability
            pred = y_.argmax(dim=1, keepdim=True)
            correct += pred.eq(y.view_as(pred)).sum().item()
        test_loss /= len(test_loader.dataset)
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))


for epoch in range(1, TRAINING_STEPS + 1):
    do_train(epoch)
    do_test()
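
After training you will probably want to keep the weights. A minimal sketch for saving and reloading them (the file name lenet5_mnist.pt is my own choice, not from the original post):

torch.save(net.state_dict(), 'lenet5_mnist.pt')      # save the trained weights

net2 = Net()                                         # fresh, untrained instance
net2.load_state_dict(torch.load('lenet5_mnist.pt'))  # restore the weights
net2.eval()                                          # evaluation mode for inference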

Running output:

[Figure: training log output]
