LSTM:長短期記憶網絡 (Long short-term memory)

LSTM :Long short-term memory

這也是RNN的一個變種網絡,在之後大家都可以見到各類變種網絡,其本質就是爲了解決某個領域問題而設計出來的,LSTM是爲了解決RNN模型存在的問題而提出來的,RNN模型存在長序列訓練過程中梯度爆炸和梯度消失的問題,無法長久的保存歷史信息,而LSTM就可以解決梯度消失和梯度爆炸問題。簡單來說,就是相比普通的RNN,LSTM能夠在更長的序列中有更好的表現。

網絡結構

LSTM的RNN的更新模塊具有4個不同的層相互作用,這四個層分別是:

  • 遺忘門
    遺忘門
    σ\sigma是指sigmoid函數,對於狀態Ct1C_{t-1}矩陣當中每個輸入的值,都會乘以一個乘子,乘子的值在[0, 1]之間,相當於是決定了遺忘多少部分。如果乘子值爲1,說明全部保留,不刪除原本的記憶,如果是0,說明狀態Ct1C_{t-1}矩陣對應的這個值全部刪除,全部遺忘。場景:比如文本中的轉折語句,前一個句子主語“He”,名字叫“Peter”,國籍是“America”。下一句新出現了一羣人,因此這個時候狀態矩陣對應主語的這一欄就會刪除“He”,以保證接下來的動詞的形式不是第三人稱單數。

  • 輸入門
    輸入門
    這裏有兩部分同時進行:一個是σ\sigma函數決定添加多少部分的新信息到前一個狀態矩陣當中(類似於權重),tanh層則根據前一個的輸入值ht1h_{t-1}​和當前的輸入值xt1x_{t-1}​產生一個新的當前狀態(也就是一個新的候選值向量,這個向量之後要加入到已有的狀態矩陣當中)。最後根據前面σ\sigma函數輸出的權重和新的候選值向量兩個共同更新原有的矩陣。其實是構建一個權重、一個輸入,權重是對輸入做一個過濾判斷。
    在這裏插入圖片描述
    最後跟歷史的輸入做加法作爲CtC_t

  • 輸出門
    輸出門
    輸出層也有一個權重,這個權重也是σ\sigma函數對輸入值ht1h_{t-1}​和當前的輸入值xt1x_{t-1}​的作用,對應圖中的oto_t,然後對CtC_t做乘法,保證對輸出的一個過濾。其實最後一個輸出yy還要經過轉換:
    y^(t)=δ(Vht+c) \hat{y}^{(t)}=\delta(Vh_{t}+c)

反向傳播

通過上節,我們可以知道誤差來自兩個地方:ltl_{t}lt+1l_{t+1},一個是tt時刻的神經單元的誤差,一個是tt時刻之後的神經單元的誤差
L=lt+lt+1 L=l_t+l_{t+1}

其中有兩個隱藏變量:δh(t)\delta_{h}^{(t)}δc(t)\delta_{c}^{(t)}
δh(t)=Lht=ltht+lt+1ht=VT(y^tyt)+lt+1ht+1ht+1ht \begin{aligned} \delta_{h}^{(t)} = \frac{\partial L}{\partial h_{t}​} &=\frac{\partial l_t}{\partial h_{t}​}+\frac{\partial l_{t+1}}{\partial h_{t}​}\\ &=V^{T}(\hat{y}^{t}-y^{t})+\frac{\partial l_{t+1}}{\partial h_{t+1}​}\frac{\partial h_{t+1}}{\partial h_{t}​} \end{aligned}
重點是這個lt+1ht\frac{\partial l_{t+1}}{\partial h_{t}​}如何計算,ht+1=ot+1tanh(Ct+1)h_{t+1}=o_{t+1} \odot tanh(C_{t+1}),其中ot+1o_{t+1}Ct+1C_{t+1}都有關於hth_t的,Ct+1=Ctft+1+it+1C^t+1C_{t+1}=C_{t} \odot f_{t+1}+i_{t+1} \odot \hat {C}_{t+1}都有關於hth_t的遞推關係,求導就比較複雜了。首先這裏δ\delta是指sigmod函數,sigmod函數求導等於:f(x)(1f(x))f(x)(1-f(x)),tanh的導數爲:1f(x)21-f(x)^2lt+1ht\frac{\partial l_{t+1}}{\partial h_{t}​}導數拆解爲:
ht+1ot+1ot+1ht=ot+1(1ot+1)tanh(Ct+1)Wo \frac{\partial h_{t+1}}{\partial o_{t+1}​}\frac{\partial o_{t+1}}{\partial h_{t}​}=o_{t+1}(1-o_{t+1})\odot tanh(C_{t+1})W_o
ht+1tanht+1tanht+1ht\frac{\partial h_{t+1}}{\partial tanh_{t+1}​}\frac{\partial tanh_{t+1}}{\partial h_{t}​}的求導比較複雜,這裏需要拆解求導
ht+1tanht+1tanht+1Ct+1=ot+1(1tanh(Ct+1)2) \frac{\partial h_{t+1}}{\partial tanh_{t+1}​}\frac{\partial tanh_{t+1}}{\partial C_{t+1}​}=o_{t+1}(1-tanh(C_{t+1})^2)
這裏我們用一個變量C\bigtriangleup C來表示ht+1tanht+1tanht+1Ct+1\frac{\partial h_{t+1}}{\partial tanh_{t+1}​}\frac{\partial tanh_{t+1}}{\partial C_{t+1}​},因爲還需要對Ct+1C_{t+1}的變量中的hth_t來求導,避免公式太長,用一個變量來替換一下,然後分別求:
Ct+1ft+1=ft+1(1ft+1)CtWfCt+1it+1=C^t+1it+1(1it+12)WiCt+1C^t+1=it+1C^t+1(1C^t+12)Wa \frac{\partial C_{t+1}}{\partial f_{t+1}​} =f_{t+1}\odot (1-f_{t+1}) \odot C_t W_f\\ \frac{\partial C_{t+1}}{\partial i_{t+1}​} =\hat {C}_{t+1}\odot i_{t+1}(1-i_{t+1}^2)W_i\\ \frac{\partial C_{t+1}}{\partial \hat {C}_{t+1}​} =i_{t+1}\odot \hat {C}_{t+1}(1-\hat {C}_{t+1}^2)W_a
所以:
lt+1ht=ot+1(1ot+1)tanh(Ct+1)Wo+CCt+1ft+1+CCt+1it+1+CCt+1C^t+1 \frac{\partial l_{t+1}}{\partial h_{t}​} =o_{t+1}(1-o_{t+1})\odot tanh(C_{t+1})W_o+ \bigtriangleup{C} \frac{\partial C_{t+1}}{\partial f_{t+1}​}+\bigtriangleup{C}\frac{\partial C_{t+1}}{\partial i_{t+1}​}+\bigtriangleup{C}\frac{\partial C_{t+1}}{\partial \hat {C}_{t+1}​}
這裏主要參考了劉建平老師的博客,鏈接在下面,可以進去詳細看看。

LSTM 時長

誤差向上一個狀態傳遞時幾乎沒有衰減,所以權值調整的時候,對於很長時間之前的狀態帶來的影響和結尾狀態帶來的影響可以同時發揮作用,最後訓練出來的模型就具有較長時間範圍內的記憶功能。

lstm如何解決梯度消失

首先說明一下梯度爆炸的解決比較簡單,比如截斷,所以大部分網絡研究的問題在於梯度消失。RNN梯度消失帶來的問題是對遠距離的信息越來越弱,因爲梯度傳過去後很小,這樣遠距離信息都沒有起到作用,所以LSTM一方面有CtC_t,通過gate機制,將矩陣乘法變爲了逐位想乘,延緩了梯度消失,可以存儲足夠遠的信息,在反向推到的誤差傳遞過程中,很多推到參數是是0|1,這樣也保證了梯度的消失和爆炸會非常延緩。我們具體來看下如何解決: LSTM通過門機制就能夠解決梯度問題。

LSTM四倍與RNN的參數也是對網絡模型有幫助的,通過參數來控制模型。

缺點

引入了很多內容,導致參數變多,也使得訓練難度加大了很多。因此很多時候我們往往會使用效果和LSTM相當但參數更少的GRU來構建大訓練量的模型。

參考博客

LSTM如何來避免梯度彌散和梯度爆炸?
Why LSTMs Stop Your Gradients From Vanishing: A View from the Backwards Pass
LSTM模型與前向反向傳播算法

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章