0. 遇到大坑
筆者在最近的項目中用到了自定義loss函數,代碼一切都準備就緒後,在訓練時遇到了梯度爆炸的問題,每次訓練幾個step後,梯度/loss都會變爲nan。一般情況下,梯度變爲nan都是出現了log(0), x/0等情況,導致結果變爲+inf,也就成了nan。
1. 問題分析
筆者需要的loss函數如下:
其中,, 。
從理論上分析,這個loss函數在反向傳播過程中很可能會遇到梯度爆炸,這是爲什麼呢?反向傳播的過程是對loss鏈式求一階導數的過程,那麼,的導數爲:
由於,這個導數又可以表示爲:
這樣的話,出現了類似於的表達式,也就會出現典型的問題了。爲了避免這個問題,首先進行了如下的改變:
經過改變,在時,不再是問題了,而是轉換爲了一個線性函數,梯度成爲了恆定的12.9,從理論上來看,避免了梯度爆炸的問題。
2. PyTorch初步實現
在實現這一過程時,依舊…遇到了大坑,下面通過示例代碼來說明:
"""
loss = mse(X, gamma_inv(X))
"""
def loss_function(x):
mask = (x < 0.003).float()
gamma_x = mask * 12.9 * x + (1-mask) * (x ** 0.5)
loss = torch.mean((x - gamma_x) ** 2)
return loss
if __name__ == '__main__':
x = Variable(torch.FloatTensor([0, 0.0025, 0.5, 0.8, 1]), requires_grad=True)
loss = loss_function(x)
print('loss:', loss)
loss.backward()
print(x.grad)
改進後的是一個分支結構,在實現時,就採用了類似於Matlab中矩陣計算的mask方式,mask定義爲,滿足條件的在mask中對應位置的值爲1,因此,mask * 12.9 * x
的結構只會保留的結果,同樣的道理,gamma_x = mask * 12.9 * x + (1-mask) * (x ** 0.5)
就實現了上述改進後的公式。
按理來說,此時,在反向傳播過程中的梯度應該是正確的,但是,上面代碼的輸出結果爲:
loss: tensor(0.0105, grad_fn=<MeanBackward1>)
tensor([ nan, 0.1416, -0.0243, -0.0167, 0.0000])
emmm…依舊爲nan,問題在理論層面得到了解決,但是,在實現層面依舊沒能解決…
3. 源碼調試分析
上面源碼的問題依舊在的實現,這個過程,在Python解釋器解釋的過程或許是這樣的:
- 計算
mask * 12.9
,對mask進行廣播式的乘法,結果爲:原本爲1的位置變爲了12.9,原本爲0的位置依舊爲0; - 將1.的結果繼續與x相乘,本質上仍然是與x的每個元素相乘,只是mask中不滿足條件的位置爲0,表現出的結果是僅對滿足條件的進行了計算;
- 按照2.所述的原理,公式的後半部分也是同樣的計算過程,即,中的每個值依舊會進行的計算;
按照上述過程進行前向傳播,在反向傳播時,梯度不是從某一個分支得到的,而是兩個分支的題目相加得到的,換句話說,依舊沒能解決梯度變爲nan的問題。
4. 源碼改進及問題解決
經過第三部分的分析,知道了梯度變爲nan的根本原因是當時依舊參與了的計算,導致在反向傳播時計算出的梯度爲nan。
要解決這個問題,就要保證在時不會進行這樣的計算。
新的PyTorch代碼如下:
def loss_function(x):
mask = x < 0.003
gamma_x = torch.FloatTensor(x.size()).type_as(x)
gamma_x[mask] = 12.9 * x[mask]
mask = x >= 0.003
gamma_x[mask] = x[mask] ** 0.5
loss = torch.mean((x - gamma_x) ** 2)
return loss
if __name__ == '__main__':
x = Variable(torch.FloatTensor([0, 0.0025, 0.5, 0.8, 1]), requires_grad=True)
loss = loss_function(x)
print('loss:', loss)
loss.backward()
print(x.grad)
改變的地方位於loss_function
,改變了對於分支的處理方式,控制並保住每次計算僅有滿足條件的值可以參與。此時輸出爲:
loss: tensor(0.0105, grad_fn=<MeanBackward1>)
tensor([ 0.0000, 0.1416, -0.0243, -0.0167, 0.0000])
就此,問題解決!
*原創博客,轉載請附加本文鏈接。