import torch as t
from matplotlib import pyplot as plt
def plot_curve(data):
    """Plot a sequence of values against its index and display the figure.

    Args:
        data: sequence of numbers (e.g. per-step training losses); the
            x-axis is the position of each value in the sequence.
    """
    plt.figure()  # draw into a new figure rather than the current one
    steps = range(len(data))
    plt.plot(steps, data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()
def plot_image(img, label, name, num=6):
    """Show the first few images of a batch in a grid, each captioned with its label.

    Args:
        img: batch of images indexed as ``img[i][0]``; presumably a
            (N, 1, H, W) tensor normalized with mean 0.1307 / std 0.3081
            like the MNIST loaders in this file -- TODO confirm for other callers.
        label: per-image labels; tensor elements (anything with ``.item()``)
            or plain Python values.
        name: caption prefix; each title is ``"<name>:<label>"``.
        num: maximum number of images to show (default 6, a 2x3 grid).
            Fewer are shown if the batch is smaller.
    """
    plt.figure()
    count = min(num, len(img))  # avoid IndexError on batches smaller than num
    rows = max(1, (count + 2) // 3)  # 3 columns; 6 images -> the original 2x3 grid
    for i in range(count):
        plt.subplot(rows, 3, i + 1)
        # Undo Normalize((0.1307,), (0.3081,)): pixel = z * std + mean.
        plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
        lbl = label[i]
        if hasattr(lbl, 'item'):
            lbl = lbl.item()
        plt.title("{}:{}".format(name, lbl))
        plt.xticks([])
        plt.yticks([])
    # Lay out once, after every subplot and title exists.
    plt.tight_layout()
    plt.show()
def one_hot(label, depth=10):
    """Encode integer class indices as one-hot float rows.

    Args:
        label: integer class indices -- a tensor of any integer dtype, or
            anything ``torch.as_tensor`` accepts (e.g. a list of ints).
            It is flattened to one label per row.
        depth: number of classes, i.e. the width of each one-hot row.

    Returns:
        Float tensor of shape (num_labels, depth) on the same device as
        ``label``, with a single 1 in each row at that row's label index.
    """
    # as_tensor (unlike the legacy t.LongTensor constructor) accepts any
    # integer dtype or device and plain Python sequences.
    idx = t.as_tensor(label, dtype=t.long).reshape(-1, 1)
    out = t.zeros(idx.size(0), depth, device=idx.device)
    out.scatter_(1, idx, 1.0)
    return out
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision
batch_size = 512

# One transform shared by both splits so train and test are normalized
# identically; (0.1307,) / (0.3081,) match the constants plot_image uses
# to undo the normalization for display.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader = t.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=mnist_transform),
    batch_size=batch_size, shuffle=True)
test_loader = t.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=False, download=True,
                               transform=mnist_transform),
    batch_size=batch_size, shuffle=False)
# Peek at one training batch: print shapes and value range, then plot samples.
x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, 'image sample')
# Output:
# torch.Size([512, 1, 28, 28]) torch.Size([512]) tensor(-0.4242) tensor(2.8215)
class Net(nn.Module):
    """Three-layer fully connected classifier: 784 -> 256 -> 64 -> 10."""

    def __init__(self):
        super().__init__()
        # 28*28 inputs: one feature per pixel of a flattened 28x28 image.
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return 10 raw (un-activated) class scores per input.

        Args:
            x: either pre-flattened input of shape [b, 784] (or a single
                [784] vector), or image-shaped input such as [b, 1, 28, 28],
                which is flattened per sample here.
        """
        if x.dim() > 2:
            x = x.flatten(1)  # keep batch dim, merge the rest
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # No activation on the output layer.
        return self.fc3(x)
# Train the MLP for 3 epochs with SGD + momentum, regressing its 10 outputs
# onto one-hot targets with MSE loss; every batch's loss is recorded.
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
train_loss = []
for epoch in range(3):
    for batch_idx, (x, y) in enumerate(train_loader):
        x = x.view(x.size(0), 28 * 28)  # [b, 1, 28, 28] -> [b, 784]
        out = net(x)
        y_onehot = one_hot(y)
        loss = F.mse_loss(out, y_onehot)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_value = loss.item()  # read the scalar once
        train_loss.append(loss_value)
        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss_value)
plot_curve(train_loss)
# Output (epoch, batch_idx, loss):
# 0 0 0.11438305675983429
# 0 10 0.09387794137001038
# 0 20 0.08220846951007843
# 0 30 0.07706790417432785
# 0 40 0.0702405497431755
# 0 50 0.06567282974720001
# 0 60 0.06183426454663277
# 0 70 0.0587509460747242
# 0 80 0.05580664798617363
# 0 90 0.05178285390138626
# 0 100 0.05214530974626541
# 0 110 0.04988808557391167
# 1 0 0.047764163464307785
# 1 10 0.04818789288401604
# 1 20 0.04520576819777489
# 1 30 0.04330906271934509
# 1 40 0.04346104711294174
# 1 50 0.04184712469577789
# 1 60 0.042132407426834106
# 1 70 0.04193287342786789
# 1 80 0.04033951088786125
# 1 90 0.03958696871995926
# 1 100 0.03668265417218208
# 1 110 0.03986677899956703
# 2 0 0.04010584205389023
# 2 10 0.037799276411533356
# 2 20 0.034503210335969925
# 2 30 0.03469271585345268
# 2 40 0.03576362878084183
# 2 50 0.0365145206451416
# 2 60 0.03551790118217468
# 2 70 0.03547125309705734
# 2 80 0.03524373471736908
# 2 90 0.03232691437005997
# 2 100 0.032091040164232254
# 2 110 0.0334344319999218
# Top-1 accuracy on the test split. no_grad: inference needs no autograd graph.
total_correct = 0
with t.no_grad():
    for x, y in test_loader:
        x = x.view(x.size(0), 28 * 28)
        out = net(x)
        pred = out.argmax(dim=1)
        total_correct += pred.eq(y).sum().item()
total_num = len(test_loader.dataset)
acc = total_correct / total_num
print('test acc:', acc)
# Output:
# test acc: 0.8902
# Plot the first test batch captioned with the model's predicted classes.
x, y = next(iter(test_loader))
with t.no_grad():  # inference only; no autograd graph needed
    out = net(x.view(x.size(0), 28 * 28))
pred = out.argmax(dim=1)
plot_image(x, pred, 'test')