本文主要是用PyTorch來實現一個簡單的迴歸任務。
編輯器:spyder
1.引入相應的包及生成僞數據
import torch
import torch.nn.functional as F # 主要實現激活函數
import matplotlib.pyplot as plt # 繪圖的工具
from torch.autograd import Variable
# 生成僞數據
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim = 1)
y = x.pow(2) + 0.2 * torch.rand(x.size())
# 變爲Variable
x, y = Variable(x), Variable(y)
其中torch.linspace
是爲了生成連續間斷的數據,第一個參數表示起點,第二個參數表示終點,第三個參數表示將這個區間分成平均幾份,即生成幾個數據。因爲torch只能處理二維的數據,所以我們用torch.unsqueeze
給僞數據添加一個維度,dim表示添加在第幾維。torch.rand
返回的是[0,1)之間的均勻分佈。
2.繪製數據圖像
在上述代碼後面加下面的代碼,然後運行可得僞數據的圖形化表示:
# 繪製數據圖像
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()
3.建立神經網絡
class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
self.predict = torch.nn.Linear(n_hidden, n_output) # output layer
def forward(self, x):
x = F.relu(self.hidden(x)) # activation function for hidden layer
x = self.predict(x) # linear output
return x
net = Net(n_feature=1, n_hidden=10, n_output=1) # define the network
print(net) # net architecture
一般神經網絡的類都繼承自torch.nn.Module
,__init__()
和forward()
兩個函數是自定義類的主要函數。在__init__()
中都要添加一句super(Net, self).__init__()
,這是固定的標準寫法,用於繼承父類的初始化函數。__init__()
中只是對神經網絡的模塊進行了聲明,真正的搭建是在forwad()
中實現。自定義類中的成員都通過self指針來進行訪問,所以參數列表中都包含了self。
如果想查看網絡結構,可以用print()
函數直接打印網絡。本文的網絡結構輸出如下:
Net (
(hidden): Linear (1 -> 10)
(predict): Linear (10 -> 1)
)
4.訓練網絡
# 訓練100次
for t in range(100):
prediction = net(x) # input x and predict based on x
loss = loss_func(prediction, y) # 一定要是輸出在前,標籤在後 (1. nn output, 2. target)
optimizer.zero_grad() # clear gradients for next train
loss.backward() # backpropagation, compute gradients
optimizer.step() # apply gradients
訓練網絡之前我們需要先定義優化器和損失函數。torch.optim
包中包括了各種優化器,這裏我們選用最常見的SGD作爲優化器。因爲我們要對網絡的參數進行優化,所以我們要把網絡的參數net.parameters()
傳入優化器中,並設置學習率(一般小於1)。
由於這裏是迴歸任務,我們選擇torch.nn.MSELoss()
作爲損失函數。
由於優化器是基於梯度來優化參數的,並且梯度會保存在其中。所以在每次優化前要通過optimizer.zero_grad()
把梯度置零,然後再後向傳播及更新。
5.可視化訓練過程
plt.ion() # something about plotting
for t in range(100):
...
if t % 5 == 0:
# plot and show learning process
plt.cla()
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.text(0.5, 0, 'Loss=%.4f' % loss.data[0], fontdict={'size': 20, 'color': 'red'})
plt.pause(0.1)
plt.ioff()
plt.show()