pytorch loss 出現nan,原因之一,計算圖中存在torch.sqrt或者 **0.5,以及如何解決

今天寫一個loss函數

dist=torch.sqrt(x*x+y*y)
loss=soomthL1loss(dist,gt_dist)

我隨便寫的幾句示意代碼,這樣會導致在第一個iteration之後出現nan,第一次iteration之內,還是可以看到loss不爲nan的。

解決辦法:

     1、不開方,因爲開方的求導會出現在分母上,因此需要避免分母爲0!

     2、torch.sqrt(x*x+0.000001)增加一個 很小的 “一瞥西漏”

給大家個測試代碼:

import torch
a = torch.zeros(1,requireds_grad = True)
b = torch.sqrt(a)
b.backward()
print(a.grad)
#得到tensor([inf]),看到inf就知道,一般來說沒辦法傳遞了,爲什麼是一般來說,因爲用過darknet的yolo的話,裏面出現inf還是可以訓練的(可能我記錯了)


#修改下:
import torch
a = torch.zeros(1,requireds_grad = True)
b = torch.sqrt(a+0.001)
b.backward()
print(a.grad)
#tensor([15.814])

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章