pytorch中model eval和torch no grad()的區別

model.eval()和with torch.no_grad()的區別

在PyTorch中進行validation時,會使用model.eval()切換到測試模式,在該模式下,

  • 主要用於通知dropout層和batchnorm層在train和val模式間切換
    • train模式下,dropout網絡層會按照設定的參數p設置保留激活單元的概率(保留概率=p); batchnorm層會繼續計算數據的mean和var等參數並更新。
    • val模式下,dropout層會讓所有的激活單元都通過,而batchnorm層會停止計算和更新mean和var,直接使用在訓練階段已經學出的mean和var值。
  • 該模式不會影響各層的gradient計算行爲,即gradient計算和存儲與training模式一樣,只是不進行反傳(backprobagation)
  • with torch.no_grad()則主要是用於停止autograd模塊的工作,以起到加速和節省顯存的作用,具體行爲就是停止gradient計算,從而節省了GPU算力和顯存,但是並不會影響dropout和batchnorm層的行爲。

使用場景

如果不在意顯存大小和計算時間的話,僅僅使用model.eval()已足夠得到正確的validation的結果;而with torch.zero_grad()則是更進一步加速和節省gpu空間(因爲不用計算和存儲gradient),從而可以更快計算,也可以跑更大的batch來測試。

參考

  • https://discuss.pytorch.org/t/model-eval-vs-with-torch-no-grad/19615/38
  • https://ryankresse.com/batchnorm-dropout-and-eval-in-pytorch/
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章