LSTM :Long short-term memory
這也是RNN的一個變種網絡,在之後大家都可以見到各類變種網絡,其本質就是爲了解決某個領域問題而設計出來的,LSTM是爲了解決RNN模型存在的問題而提出來的,RNN模型存在長序列訓練過程中梯度爆炸和梯度消失的問題,無法長久的保存歷史信息,而LSTM就可以解決梯度消失和梯度爆炸問題。簡單來說,就是相比普通的RNN,LSTM能夠在更長的序列中有更好的表現。
網絡結構
LSTM的RNN的更新模塊具有4個不同的層相互作用,這四個層分別是:
-
遺忘門
σ是指sigmoid函數,對於狀態Ct−1矩陣當中每個輸入的值,都會乘以一個乘子,乘子的值在[0, 1]之間,相當於是決定了遺忘多少部分。如果乘子值爲1,說明全部保留,不刪除原本的記憶,如果是0,說明狀態Ct−1矩陣對應的這個值全部刪除,全部遺忘。場景:比如文本中的轉折語句,前一個句子主語“He”,名字叫“Peter”,國籍是“America”。下一句新出現了一羣人,因此這個時候狀態矩陣對應主語的這一欄就會刪除“He”,以保證接下來的動詞的形式不是第三人稱單數。
-
輸入門
這裏有兩部分同時進行:一個是σ函數決定添加多少部分的新信息到前一個狀態矩陣當中(類似於權重),tanh層則根據前一個的輸入值ht−1和當前的輸入值xt−1產生一個新的當前狀態(也就是一個新的候選值向量,這個向量之後要加入到已有的狀態矩陣當中)。最後根據前面σ函數輸出的權重和新的候選值向量兩個共同更新原有的矩陣。其實是構建一個權重、一個輸入,權重是對輸入做一個過濾判斷。
最後跟歷史的輸入做加法作爲Ct。
-
輸出門
輸出層也有一個權重,這個權重也是σ函數對輸入值ht−1和當前的輸入值xt−1的作用,對應圖中的ot,然後對Ct做乘法,保證對輸出的一個過濾。其實最後一個輸出y還要經過轉換:
y^(t)=δ(Vht+c)
反向傳播
通過上節,我們可以知道誤差來自兩個地方:lt和lt+1,一個是t時刻的神經單元的誤差,一個是t時刻之後的神經單元的誤差
L=lt+lt+1
其中有兩個隱藏變量:δh(t) 和δc(t)。
δh(t)=∂ht∂L=∂ht∂lt+∂ht∂lt+1=VT(y^t−yt)+∂ht+1∂lt+1∂ht∂ht+1
重點是這個∂ht∂lt+1如何計算,ht+1=ot+1⊙tanh(Ct+1),其中ot+1和Ct+1都有關於ht的,Ct+1=Ct⊙ft+1+it+1⊙C^t+1都有關於ht的遞推關係,求導就比較複雜了。首先這裏δ是指sigmod函數,sigmod函數求導等於:f(x)(1−f(x)),tanh的導數爲:1−f(x)2,∂ht∂lt+1導數拆解爲:
∂ot+1∂ht+1∂ht∂ot+1=ot+1(1−ot+1)⊙tanh(Ct+1)Wo
∂tanht+1∂ht+1∂ht∂tanht+1的求導比較複雜,這裏需要拆解求導
∂tanht+1∂ht+1∂Ct+1∂tanht+1=ot+1(1−tanh(Ct+1)2)
這裏我們用一個變量△C來表示∂tanht+1∂ht+1∂Ct+1∂tanht+1,因爲還需要對Ct+1的變量中的ht來求導,避免公式太長,用一個變量來替換一下,然後分別求:
∂ft+1∂Ct+1=ft+1⊙(1−ft+1)⊙CtWf∂it+1∂Ct+1=C^t+1⊙it+1(1−it+12)Wi∂C^t+1∂Ct+1=it+1⊙C^t+1(1−C^t+12)Wa
所以:
∂ht∂lt+1=ot+1(1−ot+1)⊙tanh(Ct+1)Wo+△C∂ft+1∂Ct+1+△C∂it+1∂Ct+1+△C∂C^t+1∂Ct+1
這裏主要參考了劉建平老師的博客,鏈接在下面,可以進去詳細看看。
LSTM 時長
誤差向上一個狀態傳遞時幾乎沒有衰減,所以權值調整的時候,對於很長時間之前的狀態帶來的影響和結尾狀態帶來的影響可以同時發揮作用,最後訓練出來的模型就具有較長時間範圍內的記憶功能。
lstm如何解決梯度消失
首先說明一下梯度爆炸的解決比較簡單,比如截斷,所以大部分網絡研究的問題在於梯度消失。RNN梯度消失帶來的問題是對遠距離的信息越來越弱,因爲梯度傳過去後很小,這樣遠距離信息都沒有起到作用,所以LSTM一方面有Ct,通過gate機制,將矩陣乘法變爲了逐位想乘,延緩了梯度消失,可以存儲足夠遠的信息,在反向推到的誤差傳遞過程中,很多推到參數是是0|1,這樣也保證了梯度的消失和爆炸會非常延緩。我們具體來看下如何解決: LSTM通過門機制就能夠解決梯度問題。
LSTM四倍與RNN的參數也是對網絡模型有幫助的,通過參數來控制模型。
缺點
引入了很多內容,導致參數變多,也使得訓練難度加大了很多。因此很多時候我們往往會使用效果和LSTM相當但參數更少的GRU來構建大訓練量的模型。
參考博客
LSTM如何來避免梯度彌散和梯度爆炸?
Why LSTMs Stop Your Gradients From Vanishing: A View from the Backwards Pass
LSTM模型與前向反向傳播算法