【Pytorch梯度爆炸】梯度、loss在反向傳播過程中變爲nan解決方法

0. 遇到大坑

筆者在最近的項目中用到了自定義loss函數,代碼一切都準備就緒後,在訓練時遇到了梯度爆炸的問題,每次訓練幾個step後,梯度/loss都會變爲nan。一般情況下,梯度變爲nan都是出現了log(0), x/0等情況,導致結果變爲+inf,也就成了nan。

1. 問題分析

筆者需要的loss函數如下:
L=1Ni=0N1(xiΓ(xi))2\mathscr{L}=\frac{1}{N} \sum_{i=0}^{N-1}{\left(x_i - \Gamma(x_i)\right)^2}
其中,Γ(xi)=xiγ\Gamma(x_i)=x_i^\gamma, 0<γ<10<\gamma<1

從理論上分析,這個loss函數在反向傳播過程中很可能會遇到梯度爆炸,這是爲什麼呢?反向傳播的過程是對loss鏈式求一階導數的過程,那麼,Γ(xi)\Gamma(x_i)的導數爲:
dΓ(xi)dxi=γxiγ1\frac{d\Gamma(x_i)}{dx_i}=\gamma x_i^{\gamma-1}
由於0<γ<10<\gamma<1,這個導數又可以表示爲:
dΓ(xi)dxi=γxi1γ\frac{d\Gamma(x_i)}{dx_i}=\frac{\gamma}{x_i^{1-\gamma}}
這樣的話,出現了類似於1/x1/x的表達式,也就會出現典型的0/10/1問題了。爲了避免這個問題,首先進行了如下的Γ(xi)\Gamma(x_i)改變:
Γ(xi)={12.9×xi,xi<0.003xiγ,xi0.003 \Gamma(x_i)=\left\{ \begin{aligned} 12.9 \times x_i, &x_i < 0.003\\ x_i^\gamma, & x_i \geq 0.003 \end{aligned} \right.
經過改變,在xi=0x_i=0時,不再是1/01/0問題了,而是轉換爲了一個線性函數,梯度成爲了恆定的12.9,從理論上來看,避免了梯度爆炸的問題。

2. PyTorch初步實現

在實現這一過程時,依舊…遇到了大坑,下面通過示例代碼來說明:

"""
loss = mse(X, gamma_inv(X))
"""
def loss_function(x):
    mask = (x < 0.003).float()
    gamma_x = mask * 12.9 * x + (1-mask) * (x ** 0.5)
    loss = torch.mean((x - gamma_x) ** 2)
    return loss

if __name__ == '__main__':
    x = Variable(torch.FloatTensor([0, 0.0025, 0.5, 0.8, 1]), requires_grad=True)
    loss = loss_function(x)
    print('loss:', loss)
    loss.backward()
    print(x.grad)

改進後的Γ(xi)\Gamma(x_i)是一個分支結構,在實現時,就採用了類似於Matlab中矩陣計算的mask方式,mask定義爲xi&lt;0.003x_i&lt;0.003,滿足條件的xix_i在mask中對應位置的值爲1,因此,mask * 12.9 * x的結構只會保留xi&lt;0.003x_i&lt;0.003的結果,同樣的道理,gamma_x = mask * 12.9 * x + (1-mask) * (x ** 0.5)就實現了上述改進後的Γ(xi)\Gamma(x_i)公式。

按理來說,此時,在反向傳播過程中的梯度應該是正確的,但是,上面代碼的輸出結果爲:

loss: tensor(0.0105, grad_fn=<MeanBackward1>)
tensor([    nan,  0.1416, -0.0243, -0.0167,  0.0000])

emmm…依舊爲nan,問題在理論層面得到了解決,但是,在實現層面依舊沒能解決…

3. 源碼調試分析

上面源碼的問題依舊在Γ(xi)\Gamma(x_i)的實現,這個過程,在Python解釋器解釋的過程或許是這樣的:

  1. 計算mask * 12.9,對mask進行廣播式的乘法,結果爲:原本爲1的位置變爲了12.9,原本爲0的位置依舊爲0;
  2. 將1.的結果繼續與x相乘,本質上仍然是與x的每個元素相乘,只是mask中不滿足條件的xix_i位置爲0,表現出的結果是僅對滿足條件的xix_i進行了計算;
  3. 按照2.所述的原理,Γ(xi)\Gamma(x_i)公式的後半部分也是同樣的計算過程,即,xx中的每個值依舊會進行xγx^\gamma的計算;

按照上述過程進行前向傳播,在反向傳播時,梯度不是從某一個分支得到的,而是兩個分支的題目相加得到的,換句話說,依舊沒能解決梯度變爲nan的問題。

4. 源碼改進及問題解決

經過第三部分的分析,知道了梯度變爲nan的根本原因是當xi=0x_i=0時依舊參與了xiγx_i^\gamma的計算,導致在反向傳播時計算出的梯度爲nan。

要解決這個問題,就要保證在xi=0x_i=0時不會進行這樣的計算。

新的PyTorch代碼如下:

def loss_function(x):
    mask = x < 0.003
    gamma_x = torch.FloatTensor(x.size()).type_as(x)
    gamma_x[mask] = 12.9 * x[mask]
    mask = x >= 0.003
    gamma_x[mask] = x[mask] ** 0.5
    loss = torch.mean((x - gamma_x) ** 2)
    return loss

if __name__ == '__main__':
    x = Variable(torch.FloatTensor([0, 0.0025, 0.5, 0.8, 1]), requires_grad=True)
    loss = loss_function(x)
    print('loss:', loss)
    loss.backward()
    print(x.grad)

改變的地方位於loss_function,改變了對於Γ(xi)\Gamma(x_i)分支的處理方式,控制並保住每次計算僅有滿足條件的值可以參與。此時輸出爲:

loss: tensor(0.0105, grad_fn=<MeanBackward1>)
tensor([ 0.0000,  0.1416, -0.0243, -0.0167,  0.0000])

就此,問題解決!

*原創博客,轉載請附加本文鏈接。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章