Pytorch 快速搭建手寫識別網絡
Pytorch是一個十分簡潔的神經網絡框架,可以快速地構建出需要的機器學習網絡。
Pytorch的安裝和簡介請看:https://blog.csdn.net/qq_33302004/article/details/106320649
本文以手寫識別問題爲基礎,搭建LeNet-5網絡,網絡結構如下圖:
與圖中不同的是我的數據集中輸入數據爲28*28(而不是32*32),所以第一個全連接層爲16*4*4——120。代碼如下:
#coding=utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision # 包含有支持加載類似Imagenet,CIFAR10,MNIST 等公共數據集的數據加載模塊
from torchvision import datasets, transforms
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np
# Hyperparameters
batch_size = 100        # samples per mini-batch
learning_rate = 0.01    # SGD learning rate
momentum = 0.5          # SGD momentum
TRAINING_STEPS = 10     # number of training epochs

# Preprocessing: convert to tensor, then normalize the single MNIST
# channel to roughly [-1, 1].  Normalize expects per-channel mean/std
# sequences; the original Normalize(0.5, 0.5, 0.5) passed 0.5 into the
# third (inplace) parameter instead of supplying mean and std tuples.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Load MNIST with the transform above.  The original built `transform`
# but then passed a bare ToTensor() here, so the normalization was never
# applied even though imshow() un-normalizes with img/2 + 0.5.
dataset_train = datasets.MNIST('MNIST/', train=True, transform=transform)
dataset_test = datasets.MNIST('MNIST/', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(
    dataset=dataset_train,
    batch_size=batch_size,
    shuffle=True  # reshuffle the training data every epoch
)
test_loader = torch.utils.data.DataLoader(
    dataset=dataset_test,
    batch_size=batch_size,
    shuffle=False
)
class Net(nn.Module):
    """LeNet-5 style CNN for 28x28 single-channel MNIST digits.

    Layout: conv(1->6, 5x5) -> 2x2 max-pool -> conv(6->16, 5x5)
    -> 2x2 max-pool -> fc(16*4*4 -> 120) -> fc(120 -> 84) -> fc(84 -> 10).
    With 28x28 inputs the feature map entering fc1 is 16 channels of
    4x4 (not the 5x5 of the classic 32x32 LeNet-5), hence 16*4*4.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Feature extractor
        self.conv1 = nn.Conv2d(1, 6, 5)      # 1 input channel -> 6, 5x5 kernel
        self.pool1 = nn.MaxPool2d(2)         # halves spatial dimensions
        self.conv2 = nn.Conv2d(6, 16, 5)     # 6 channels -> 16, 5x5 kernel
        self.pool2 = nn.MaxPool2d(2)         # halves spatial dimensions again
        # Classifier head
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)         # 10 digit classes

    def forward(self, x):
        """Return per-class log-probabilities of shape (N, 10)."""
        out = self.pool1(F.relu(self.conv1(x)))
        out = self.pool2(F.relu(self.conv2(out)))
        out = out.view(-1, 16 * 4 * 4)       # flatten for the fully-connected layers
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        return F.log_softmax(self.fc3(out), dim=1)
def imshow(img):
    """Display a (C, H, W) image tensor, undoing the [-1, 1] normalization."""
    restored = img * 0.5 + 0.5  # map [-1, 1] back to [0, 1]
    plt.imshow(np.transpose(restored.numpy(), (1, 2, 0)))  # (C,H,W) -> (H,W,C) for matplotlib
    plt.show()
# Instantiate the network
net = Net()
# Build the optimizer: plain SGD with momentum over all network parameters.
# NOTE: weight_decay (L2 regularization) is not set and stays at its default of 0.
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum)
def do_train(epoch):
    """Run one training epoch over train_loader, logging every 200 batches.

    Args:
        epoch: 1-based epoch index, used only in the log output.
    """
    net.train()  # switch to training mode
    for step, (x, y) in enumerate(train_loader):
        # NOTE: the original wrapped x/y in torch.autograd.Variable, which
        # has been a deprecated no-op since PyTorch 0.4 — tensors from the
        # DataLoader carry autograd state directly.
        y_ = net(x)
        loss = F.nll_loss(y_, y)  # NLL pairs with the model's log_softmax output
        optimizer.zero_grad()     # clear gradients from the previous step
        loss.backward()
        # update
        optimizer.step()
        if step % 200 == 0:
            print('Train Epoch: ', epoch, ' [', step * len(x), '/', len(train_loader.dataset), ' (', 100. * step / len(train_loader), '%)]\tLoss: ', loss.item())
def do_test():
    """Evaluate the network on test_loader; print average loss and accuracy."""
    net.eval()  # switch to evaluation mode
    test_loss = 0
    correct = 0
    # no_grad disables autograd for the whole evaluation; the original's
    # Variable(x, volatile=True) is deprecated (rejected by modern PyTorch)
    # and redundant under no_grad.
    with torch.no_grad():
        for x, y in test_loader:
            y_ = net(x)
            # sum up batch loss.  reduction='sum' accumulates the total loss
            # over all samples; the original summed per-batch *means*, so
            # dividing by the dataset size under-reported the average loss
            # by a factor of batch_size.
            test_loss += F.nll_loss(y_, y, reduction='sum').item()
            # get the index of the max log-probability = predicted class
            pred = y_.max(1, keepdim=True)[1]
            correct += pred.eq(y.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: ', test_loss, ', Accuracy: ', int(correct), '/', len(test_loader.dataset),' (', 100. * int(correct) / len(test_loader.dataset),'%)\n')
# Train for TRAINING_STEPS epochs, evaluating after each one.
# The original range(1, TRAINING_STEPS) ran only TRAINING_STEPS - 1
# epochs (off-by-one against the "number of epochs" hyperparameter).
for epoch in range(1, TRAINING_STEPS + 1):
    do_train(epoch)
    do_test()
運行輸出: