LSTM與梯度消失

1. 標準RNN中處理序列數據的方法是將上一個state的信息傳到下一個state中,表示成數學公式爲st=f(W*(st-1,xt)+b),其中f爲激活函數。在反向傳播中,根據求導的鏈式法則,這種形式求得的梯度爲一個矩陣W與激活函數導數的乘積。如果進行n次反向傳播,梯度變化將會變爲(W*f”)的n次方累乘。

(1)如果乘積大於1,則梯度會隨着反向傳播層數n的增加而成指數增長,導致梯度爆炸;

(2)如果乘積小於1,經過多層傳播後,小於1的數累乘後結果趨近於0,導致梯度消失。

2. 現在的lstm主要用來緩解梯度消失問題,其中,St主要由兩部分組成,表示過去信息的St-1和表示現在信息St(~),數學公式爲St = ft*St-1+it*f(St-1)

其中第二input部分與標準RNN類似,在反向傳播中可能會逐漸消失,而第一forget部分經過多次傳播之後,會出現ft*ft-1*ft-2……這樣的累乘,ft的大小是可選的,可以有效減輕梯度消失。

3. 梯度爆炸一般靠裁剪後的優化算法即可解決,比如gradient clipping(如果梯度的範數大於某個給定值,將梯度同比收縮)。

4. 關於lstm中的forget gate的理解:

(1)原始的lstm是沒有forget gate的,或者說相當於forget gate恆爲1,所以不存在梯度消失問題

(2)現在的lstm被引入了forget gate,但是lstm的一個初始化技巧就是將forget gate的bias置爲正數(例如1或5,這點可以查看各大框架代碼),這樣模型剛開始訓練時forget gate的值接近於1,不回發生梯度消失

(3)隨着訓練過程的進行,forget gate就不再恆爲1了。不過對於一個已經訓練好的模型,需要選擇性地記住或者遺忘某些信息,所以forget gate要麼是1,要麼是0,很少有類似0.5這樣的中間值,相當於一個二元的開關。例如在某個序列裏,forget gate全爲1,那麼梯度不會消失;否則,若某一個forget gate是0,這時候雖然會導致梯度消失,但是體現了模型的選擇性,刻意遺忘某些信息。

綜上,lstm可以緩解梯度消失問題,但不能徹底避免。

 

 

 

 

 

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章