torch.no_grad和驗證模式

1.requires_grad

requires_gradVariable變量的requires_grad的屬性默認爲False,若一個節點requires_grad被設置爲True,那麼所有依賴它的節點的requires_grad都爲True。

volatile=True是Variable的另一個重要的標識,它能夠將所有依賴它的節點全部設爲volatile=True,其優先級比requires_grad=True高。因而volatile=True的節點不會求導,即使requires_grad=True,也不會進行反向傳播,對於不需要反向傳播的情景(inference,測試階段推斷階段),該參數可以實現一定速度的提升,並節省一半的顯存,因爲其不需要保存梯度。
注意:該屬性已經在0.4版本中被移除了

2.with torch.no_grad()和@torch.no_grad()

torch.no_grad()是新版本pytorch中volatile的替代

>>> x = torch.tensor([1], requires_grad=True)
>>> with torch.no_grad():
...   y = x * 2
>>> y.requires_grad
False
>>> @torch.no_grad()
... def doubler(x):
...     return x * 2
>>> z = doubler(x)
>>> z.requires_grad
False

3.model.eval()和with torch.no_grad()區別

在PyTorch中進行validation時,會使用model.eval()切換到測試模式:

  • 1.主要用於通知dropout層和batchnorm層在train和val模式間切換
    在train模式下,dropout網絡層會按照設定的參數p設置保留激活單元的概率(保留概率=p); batchnorm層會繼續計算數據的mean和var等參數並更新。
    在val模式下,dropout層會讓所有的激活單元都通過,而batchnorm層會停止計算和更新mean和var,直接使用在訓練階段已經學出的mean和var值。
  • 2.該模式不會影響各層的gradient計算行爲,即gradient計算和存儲與training模式一樣,只是不進行反傳(backprobagation)
  • 3.with torch.zero_grad()則主要是用於停止autograd模塊的工作,以起到加速和節省顯存的作用,具體行爲就是停止gradient計算,從而節省了GPU算力和顯存,但是並不會影響dropout和batchnorm層的行爲。

如果不在意顯存大小和計算時間的話,僅僅使用model.eval()已足夠得到正確的validation的結果;而with torch.zero_grad()則是更進一步加速和節省gpu空間(因爲不用計算和存儲gradient),從而可以更快計算,也可以跑更大的batch來測試。

參考
1.torch.no_grad
2.torch.no_grad
3.pytorch中model.eval()和“with torch.no_grad()區別

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章