pytorch系列(六):各種優化器的性能比較

import torch
import torch.utils.data as Data
import torch.nn.functional as f
import matplotlib.pyplot as plt

#指定超參數
LR=0.01#學習率
BATCH_SIZE=32#批數據的大小
EPOCH=12#迭代次數

#構造數據集
x=torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y=x.pow(2)+0.1*(torch.normal(torch.zeros(*x.size())))


#打印數據
plt.scatter(x.data.numpy(),y.data.numpy(),c='r')
plt.show()

#使用dataloader工具進行數據的處理
torch_dataset=Data.TensorDataset(x,y)#將x和y轉換成torch可識別的數據集
loader=Data.DataLoader(
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2
)

#構造網絡結構併爲每一個優化器優化一個神經網絡
class Net(torch.nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.hidden=torch.nn.Linear(1,20)
        self.output=torch.nn.Linear(20,1)
    #前向傳播
    def forward(self,x):
        x=f.relu(self.hidden(x))
        x=self.output(x)
        return x

#每一個優化器對應一個網絡結構
net_SGD=Net()
net_Momentum=Net()
net_RMSprop=Net()
net_Adam=Net()

#放到一個列表中
nets=[net_SGD,net_Momentum,net_RMSprop,net_Adam]


#API化每一個優化器
opt_SGD=torch.optim.SGD(net_SGD.parameters(),lr=LR)
opt_Momentum=torch.optim.SGD(net_Momentum.parameters(),lr=LR,momentum=0.8)
opt_RMSprop=torch.optim.RMSprop(net_RMSprop.parameters(),lr=LR,alpha=0.9)
opt_Adam=torch.optim.Adam(net_Adam.parameters(),lr=LR,betas=(0.9,0.99))

#用一個列表存放每一個優化器
optimizers=[opt_SGD,opt_Momentum,opt_RMSprop,opt_Adam]
#指定損失函數
loss_func=torch.nn.MSELoss()
#用一個兩層列表記錄各個優化器的loss
loss_his=[[],[],[],[]]

#訓練  可視化
for epoch in range(EPOCH):
    print(epoch)
    for step,(batch_x,batch_y) in enumerate(loader):
        
        #對於每一個優化器,優化他的神經網絡
        for net,opt,l_his in zip(nets,optimizers,loss_his):
            output=net(batch_x)#對每一個網絡丟入數據
            loss=loss_func(output,batch_y)#計算預測值和真實值之間的誤差
            opt.zero_grad()#梯度清零
            loss.backward()#反向傳播
            opt.step()#更新每一個參數
            l_his.append(loss.data.numpy())

#可視化
lables=["SGD","Momentum","RMSprop","Adam"]
for i,l_his in enumerate(loss_his):#enumerate是列舉,會迭代列表的中的每一個索引和每一項的值
    plt.plot(l_his,label=lables[i])
plt.legend(loc=1)#legend是做一個圖例說明  loc=1表示放在右邊  詳情看參數,label=lables[i]相對應
plt.xlabel("steps")
plt.ylabel("loss")
plt.ylim((0,0.5))
plt.show()

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章