import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
# Hyperparameters for the run.
input_size = 28 * 28  # MNIST images are 28x28 pixels, flattened into one feature vector
num_classes = 10  # one class per digit 0-9
num_epochs = 5  # number of full passes over the training set
batch_size = 100  # samples per mini-batch
learning_rate = 0.001  # SGD step size
# Load the MNIST training and test splits from ./data.
# ToTensor converts each PIL image into a float tensor with values in [0, 1].
train_dataset = torchvision.datasets.MNIST(
    root="./data",
    train=True,
    transform=transforms.ToTensor(),
    download=True,  # fetch the dataset on first run
)
# No download flag here: the test split relies on the download triggered by the
# training-split call above (torchvision's MNIST download fetches both splits).
test_dataset = torchvision.datasets.MNIST(
    root="./data",
    train=False,
    transform=transforms.ToTensor(),
)
# Wrap both datasets in mini-batch loaders. Only the training data is shuffled;
# test order does not matter for accuracy.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)
# Model, loss and optimizer.
# A single linear layer maps flattened pixels to per-class logits. Paired with
# CrossEntropyLoss (which applies log-softmax internally), this is multinomial
# logistic regression.
model = nn.Linear(in_features=input_size, out_features=num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)
# Training: make num_epochs passes over the training set, taking one SGD step
# per mini-batch.
total_step = len(train_loader)
for epoch in range(num_epochs):
    for step, (images, labels) in enumerate(train_loader, start=1):
        # Flatten each image into a row of input_size features
        # (presumably 1x28x28 per sample from ToTensor -- confirm if the transform changes).
        images = images.reshape(-1, input_size)

        # Forward pass.
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update; stale gradients are cleared first.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report the current batch loss every 100 steps.
        if step % 100 == 0:
            print("Epoch [{}/{}], Step [{}/{}], Loss:{:.4f}".format(epoch + 1, num_epochs, step, total_step, loss.item()))
# Training is done: measure classification accuracy on the test set.
# Gradients are not needed for evaluation, so autograd tracking is disabled.
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.reshape(-1, input_size)
        outputs = model(images)
        # The predicted class is the index of the largest logit in each row.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        # .item() turns the 0-d count tensor into a Python int. Without it,
        # `correct` becomes a tensor, and on current PyTorch the accuracy below
        # prints as a tensor repr such as "tensor(82.1300)" instead of a number.
        correct += (predicted == labels).sum().item()
print("Accuracy of the model on the 10000 test images:{} %".format(100 * correct / total))
# Save only the learned parameters (the state_dict), not the whole module object.
checkpoint_path = "logisticModel.ckpt"
torch.save(model.state_dict(), checkpoint_path)
運行時會先下載相關數據集，隨後開始訓練模型，並在測試數據集上評估模型的準確率。最終結果如下（僅列出最後兩個 epoch 的輸出）：
Epoch [4/5], Step [100/600], Loss:1.2736
Epoch [4/5], Step [200/600], Loss:1.1791
Epoch [4/5], Step [300/600], Loss:1.2005
Epoch [4/5], Step [400/600], Loss:1.2209
Epoch [4/5], Step [500/600], Loss:1.0825
Epoch [4/5], Step [600/600], Loss:1.0868
Epoch [5/5], Step [100/600], Loss:1.2321
Epoch [5/5], Step [200/600], Loss:0.9939
Epoch [5/5], Step [300/600], Loss:1.0931
Epoch [5/5], Step [400/600], Loss:1.1036
Epoch [5/5], Step [500/600], Loss:1.0602
Epoch [5/5], Step [600/600], Loss:0.9724
Accuracy of the model on the 10000 test images:82 %
推薦一個公衆號：健哥聊量化。該公衆號會持續推出股票相關的基礎知識，以及用 Python 實現的一些基本分析代碼。歡迎大家關注，二維碼如下：