pytorch進行CIFAR-10分類(4)訓練
我的系列博文:
Pytorch打怪路(一)pytorch進行CIFAR-10分類(1)CIFAR-10數據加載和處理
Pytorch打怪路(一)pytorch進行CIFAR-10分類(2)定義卷積神經網絡
Pytorch打怪路(一)pytorch進行CIFAR-10分類(3)定義損失函數和優化器
Pytorch打怪路(一)pytorch進行CIFAR-10分類(4)訓練(本文)
Pytorch打怪路(一)pytorch進行CIFAR-10分類(5)測試
1、簡述
經過前面的數據加載和網絡定義後,就可以開始訓練了,這裏會看到前面遇到的一些東西究竟在後面會有什麼用,所以這一步希望各位也能仔細研究一下
2、代碼
for epoch in range(2): # loop over the dataset multiple times 指定訓練一共要循環幾個epoch
running_loss = 0.0 #定義一個變量方便我們對loss進行輸出
for i, data in enumerate(trainloader, 0): # 這裏我們遇到了第一步中出現的trailoader,代碼傳入數據
# enumerate是python的內置函數,既獲得索引也獲得數據,詳見下文
# get the inputs
inputs, labels = data # data是從enumerate返回的data,包含數據和標籤信息,分別賦值給inputs和labels
# wrap them in Variable
inputs, labels = Variable(inputs), Variable(labels) # 將數據轉換成Variable,第二步裏面我們已經引入這個模塊
# 所以這段程序裏面就直接使用了,下文會分析
# zero the parameter gradients
optimizer.zero_grad() # 要把梯度重新歸零,因爲反向傳播過程中梯度會累加上一次循環的梯度
# forward + backward + optimize
outputs = net(inputs) # 把數據輸進網絡net,這個net()在第二步的代碼最後一行我們已經定義了
loss = criterion(outputs, labels) # 計算損失值,criterion我們在第三步裏面定義了
loss.backward() # loss進行反向傳播,下文詳解
optimizer.step() # 當執行反向傳播之後,把優化器的參數進行更新,以便進行下一輪
# print statistics # 這幾行代碼不是必須的,爲了打印出loss方便我們看而已,不影響訓練過程
running_loss += loss.data[0] # 從下面一行代碼可以看出它是每循環0-1999共兩千次纔打印一次
if i % 2000 == 1999: # print every 2000 mini-batches 所以每個2000次之類先用running_loss進行累加
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 2000)) # 然後再除以2000,就得到這兩千次的平均損失值
running_loss = 0.0 # 這一個2000次結束後,就把running_loss歸零,下一個2000次繼續使用
print('Finished Training')
3、分析
①autograd
在第二步中我們定義網絡時定義了前向傳播函數,但是並沒有定義反向傳播函數,可是深度學習是需要反向傳播求導的,
Pytorch其實利用的是Autograd模塊來進行自動求導,反向傳播。
Autograd中最核心的類就是Variable了,它封裝了Tensor,並幾乎支持所有Tensor的操作,這裏可以參考官方給的詳細解釋:
以上鍊接詳細講述了variable究竟是怎麼能夠實現自動求導的,怎麼用它來實現反向傳播的。
這裏涉及到計算圖的相關概念,這裏我不詳細講,後面會寫相關博文來討論這個東西,暫時不會對我們理解這個程序造成影響
只說一句,想要計算各個variable的梯度,只需調用根節點的backward方法,Autograd就會自動沿着整個計算圖進行反向計算
而在此例子中,根節點就是我們的loss,所以:
程序中的loss.backward()代碼就是在實現反向傳播,自動計算所有的梯度。
所以訓練部分的代碼其實比較簡單:
running_loss和後面負責打印損失值的那部分並不是必須的,所以關鍵行不多,總得來說分成三小節
第一節:把最開始放在trainloader裏面的數據給轉換成variable,然後指定爲網絡的輸入;
第二節:每次循環新開始的時候,要確保梯度歸零
第三節:forward+backward,就是調用我們在第三步裏面實例化的net()實現前傳,loss.backward()實現後傳
每結束一次循環,要確保梯度更新