import torch
import torchvision#包含一些數據集,像如mnist
import torch.utils.data as Data
import torch.nn as nn
import matplotlib.pyplot as plt
# Hyperparameters
EPOCH = 1  # number of passes over the training set
BATCH_SIZE = 50  # samples per mini-batch served by train_loader
LR = 0.01  # learning rate handed to the optimizer
DOWNLOAD_MNIST = True  # download the dataset into the root directory if needed

# MNIST handwritten-digit training split
train_data = torchvision.datasets.MNIST(
    root="mnist",
    train=True,  # select the training split
    # ToTensor turns a PIL image / numpy array into a FloatTensor laid out
    # as (C, H, W) with values scaled to [0, 1].
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST,  # only needed the first time; later runs reuse the files
)

# Test split. No transform is attached, so the raw pixel tensor is
# normalised by hand when test_x is built below.
test_data = torchvision.datasets.MNIST("mnist", train=False)

# Mini-batch loader over the training split, reshuffled every epoch
train_loader = Data.DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Evaluation subset: only the first 2000 test samples.
# `.data` / `.targets` replace the deprecated `.test_data` / `.test_labels`
# aliases, which only forward to them while emitting a warning.
# Slicing before the float conversion avoids converting the whole test set.
# Shape goes from (2000, 28, 28) to (2000, 1, 28, 28); dividing by 255
# maps the raw pixel values into [0, 1].
test_x = test_data.data[:2000].unsqueeze(dim=1).type(torch.FloatTensor) / 255.
test_y = test_data.targets[:2000]
# Convolutional network model
class CNN(torch.nn.Module):
    """Two-stage convolutional classifier for 28x28 images.

    Each stage is Conv2d(kernel 5, stride 1, padding 2) -> ReLU ->
    MaxPool2d(2). The padding keeps the spatial size through the
    convolution, and every pooling step halves it, so a 28x28 input becomes
    16x14x14 after ``conv1`` and 32x7x7 after ``conv2``. The flattened
    features are then mapped to class scores by one linear layer.

    Args:
        in_channels: Number of channels of the input images (default 1).
        num_classes: Number of output scores per sample (default 10).
    """

    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()
        # First convolution stage: (in_channels, 28, 28) -> (16, 14, 14)
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=in_channels,
                out_channels=16,
                kernel_size=5,
                stride=1,
                # padding = (kernel_size - 1) / 2 keeps height and width unchanged
                padding=2,
            ),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2),  # halves height and width
        )
        # Second convolution stage: (16, 14, 14) -> (32, 7, 7)
        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv2d(16, 32, 5, 1, 2),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2),
        )
        # Fully connected layer; 32 * 7 * 7 assumes 28x28 input images
        self.output = torch.nn.Linear(32 * 7 * 7, num_classes)

    def forward(self, x):
        """Return raw class scores of shape (batch_size, num_classes).

        Args:
            x: Image batch of shape (batch_size, in_channels, 28, 28).
        """
        x = self.conv1(x)
        x = self.conv2(x)
        # Flatten the feature maps to (batch_size, 32 * 7 * 7)
        x = x.view(x.size(0), -1)
        return self.output(x)
# Instantiate the convolutional network and print its layer structure
cnn = CNN()
print(cnn)

# Training setup
# Optimizer
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
# Loss function: CrossEntropyLoss is fed the raw network outputs together
# with the integer class labels yielded by train_loader; the labels are
# class indices, not one-hot vectors.
loss_func = torch.nn.CrossEntropyLoss()

# Training
cnn.train()
for epoch in range(EPOCH):
    # Mini-batch training
    for step, (batch_x, batch_y) in enumerate(train_loader):
        output = cnn(batch_x)
        loss = loss_func(output, batch_y)
        optimizer.zero_grad()  # clear gradients left over from the previous step
        loss.backward()
        optimizer.step()

# Prediction on the first 10 test samples.
# Gradients are not needed for inference, so the forward pass runs under
# no_grad instead of detaching the result afterwards through `.data`.
cnn.eval()
with torch.no_grad():
    test_output = cnn(test_x[:10])
# Index of the highest score along the class dimension is the predicted digit
pred_y = test_output.argmax(dim=1).numpy()
print(pred_y, 'prediction number')
print(test_y[:10].numpy(), 'real number')