關於LSTM解決梯度彌散爆炸問題解析

轉自知乎@Towser 原鏈接



 

“LSTM 能解決梯度消失/梯度爆炸”是對 LSTM 的經典誤解。這裏我先給出幾個粗線條的結論,詳細的回答以後有時間了再擴展:

1、首先需要明確的是,RNN 中的梯度消失/梯度爆炸和普通的 MLP 或者深層 CNN 中梯度消失/梯度爆炸的含義不一樣。MLP/CNN 中不同的層有不同的參數,各是各的梯度;而 RNN 中同樣的權重在各個時間步共享,最終的梯度 g = 各個時間步的梯度 g_t 的和。

2、由 1 中所述的原因,RNN 中總的梯度是不會消失的。即便梯度越傳越弱,那也只是遠距離的梯度消失,由於近距離的梯度不會消失,所有梯度之和便不會消失。RNN 所謂梯度消失的真正含義是,梯度被近距離梯度主導,導致模型難以學到遠距離的依賴關係。

3、LSTM 中梯度的傳播有很多條路徑這條路徑上只有逐元素相乘和相加的操作,梯度流最穩定;但是其他路徑(例如 )上梯度流與普通 RNN 類似,照樣會發生相同的權重矩陣反覆連乘。

4、LSTM 剛提出時沒有遺忘門,或者說相當於 ,這時候在 直接相連的短路路徑上, 可以無損地傳遞給 ,從而這條路徑上的梯度暢通無阻,不會消失。類似於 ResNet 中的殘差連接。

5、但是在其他路徑上,LSTM 的梯度流和普通 RNN 沒有太大區別,依然會爆炸或者消失。由於總的遠距離梯度 = 各條路徑的遠距離梯度之和,即便其他遠距離路徑梯度消失了,只要保證有一條遠距離路徑(就是上面說的那條高速公路)梯度不消失,總的遠距離梯度就不會消失(正常梯度 + 消失梯度 = 正常梯度)。因此 LSTM 通過改善一條路徑上的梯度問題拯救了總體的遠距離梯度

6、同樣,因爲總的遠距離梯度 = 各條路徑的遠距離梯度之和,高速公路上梯度流比較穩定,但其他路徑上梯度有可能爆炸,此時總的遠距離梯度 = 正常梯度 + 爆炸梯度 = 爆炸梯度,因此 LSTM 仍然有可能發生梯度爆炸。不過,由於 LSTM 的其他路徑非常崎嶇,和普通 RNN 相比多經過了很多次激活函數(導數都小於 1),因此 LSTM 發生梯度爆炸的頻率要低得多。實踐中梯度爆炸一般通過梯度裁剪來解決。

7、對於現在常用的帶遺忘門的 LSTM 來說,6 中的分析依然成立,而 5 分爲兩種情況:其一是遺忘門接近 1(例如模型初始化時會把 forget bias 設置成較大的正數,讓遺忘門飽和),這時候遠距離梯度不消失;其二是遺忘門接近 0,但這時模型是故意阻斷梯度流的,這不是 bug 而是 feature(例如情感分析任務中有一條樣本 “A,但是 B”,模型讀到“但是”後選擇把遺忘門設置成 0,遺忘掉內容 A,這是合理的)。當然,常常也存在 f 介於 [0, 1] 之間的情況,在這種情況下只能說 LSTM 改善(而非解決)了梯度消失的狀況。

 
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章