PyTorch代碼識別手寫數字

完整的PyTorch代碼識別手寫數字

# -*- coding: utf-8 -*-
import torch
import torchvision
from torchvision import datasets, transforms
# 1. 加載MNIST手寫數字數據集數據和標籤
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, ), (0.5, ))])
trainset = datasets.MNIST(root='./data', train=True,
                            download=True, transform=transform)
trainsetloader = torch.utils.data.DataLoader(trainset, batch_size=20000, shuffle=True)

testset = datasets.MNIST(root='./data', train=True,
                            download=True, transform=transform)
testsetloader = torch.utils.data.DataLoader(testset, batch_size=20000, shuffle=True)

#######如果你不放心數據有沒有加載出可以將圖片顯示出來看下#######
# dataiter = iter(trainsetloader)
# images, labels = dataiter.next()
# import numpy as np
# import matplotlib.pyplot as plt
# plt.imshow(images[0].numpy().squeeze())
# plt.show()
# print(images.shape)
# print(labels.shape)
##########上面這段是顯示圖片的代碼#############


# 2. 設計網絡結構
first_in, first_out, second_out = 28*28,  128, 10
model = torch.nn.Sequential(
    torch.nn.Linear(first_in, first_out),
    torch.nn.ReLU(),
    torch.nn.Linear(first_out, second_out),
)

# 3. 設計損失函數
loss_fn = torch.nn.CrossEntropyLoss()

# 4. 設置用於自動調節神經網絡參數的優化器
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# 5. 訓練神經網絡(重複訓練10次)
for t in range(10):
    for i, one_batch in enumerate(trainsetloader,0):
        data,label = one_batch
        data[0].view(1,784)# 將28x28的圖片變成784的向量
        data = data.view(data.shape[0],-1)

        # 讓神經網絡根據現有的參數,根據當前的輸入計算一個輸出
        model_output = model(data)
        # 5.1 用所設計算損失(誤差)函數計算誤差
        loss = loss_fn(model_output , label)
        if i%500 == 0:
            print(loss)
        # 5.2 每次訓練前清零之前計算的梯度(導數)
        optimizer.zero_grad()
        # 5.3 根據誤差反向傳播計算誤差對各個權重的導數
        loss.backward()
        # 5.4 根據優化器裏面的算法自動調整神經網絡權重
        optimizer.step()

# 保存下訓練好的模型,省得下次再重新訓練
torch.save(model,'./my_handwrite_recognize_model.pt')
    

##########現在你已經訓練好了#################
# 6. 用這個神經網絡解決你的問題,比如手寫數字識別,輸入一個圖片矩陣,然後模型返回一個數字
testdataiter = iter(testsetloader)
testimages, testlabels = testdataiter.next()

img_vector = testimages[0].squeeze().view(1,-1)
# 模型返回的是一個1x10的矩陣,第幾個元素值最大那就是表示模型認爲當前圖片是數字幾
result_digit = model(img_vector)
print("該手寫數字圖片識別結果爲:", result_digit.max(1)[1],"標籤爲:",testlabels[0])
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章