LSTM詳解

前言

之前的文章講解了RNN的基本結構BPTT算法及梯度消失問題,說到了RNN無法解決長期依賴問題,本篇文章要講的LSTM很好地解決了這個問題。本文部分內容翻譯自Understanding LSTM Networks

文章分爲四個部分:

  • RNN與LSTM的對比
  • LSTM的核心思想
  • LSTM公式和結構詳解
  • LSTM變體介紹

一. RNN與LSTM對比

1.公式對比:

首先對RNN的公式做一下變形:
st=tanh(Wsst1+Wxxt+b)=tanh(W[st1,xt]+b)ot=softmax(Vst+c) \begin{aligned} s_t &=tanh(W_ss_{t-1}+W_xx_t+b)\\ &=tanh(W[s_{t-1},x_t]+b)\\ o_t &=softmax(Vs_t+c) \\ \end{aligned}

其中:[st1,xt][s_{t-1},x_t]表示把st1s_{t-1}xtx_t兩個向量連接成一個更長的向量。所以有W[st1,xt]=Wsst1+WxxtW[s_{t-1},x_t]=W_ss_{t-1}+W_xx_t,寫成矩陣乘法形式:
[W][st1xt]=[WsWx][st1xt]=Wsst1+Wxxt \begin{aligned} \begin{bmatrix}W\end{bmatrix}\begin{bmatrix}\mathbf{s}_{t-1}\\ \mathbf{x}_t\end{bmatrix}&= \begin{bmatrix}W_{s}&W_{x}\end{bmatrix}\begin{bmatrix}\mathbf{s}_{t-1}\\ \mathbf{x}_t\end{bmatrix}\\ &=W_{s}\mathbf{s}_{t-1}+W_{x}\mathbf{x}_t \end{aligned}

所以有:

RNN:

st=tanh(W[st1,xt]+b)ot=softmax(Vst+c) \begin{aligned} s_t &=tanh(W[s_{t-1},x_t]+b)\\ o_t &=softmax(Vs_t+c) \\ \end{aligned}

LSTM:

ft=σ(Wf[ht1,xt]+bf)           it=σ(Wi[ht1,xt]+bi)            ot=σ(Wo[ht1,xt]+oi)            C~t=tanh(WC[ht1,xt]+bC)    Ct=ftCt1+itC~t                 cell stateht=ottanh(Ct)                            \begin{aligned} f_t &=\sigma (W_f\cdot[h_{t-1},x_t]+b_f) \ \ \ \ \ \ \ \ \ \ \ 遺忘門\\ i_t &=\sigma (W_i\cdot[h_{t-1},x_t]+b_i) \ \ \ \ \ \ \ \ \ \ \ \ 輸入門 \\ o_t &=\sigma (W_o\cdot[h_{t-1},x_t]+o_i) \ \ \ \ \ \ \ \ \ \ \ \ 輸出門 \\ \widetilde{C}_t &=tanh(W_C\cdot [h_{t-1},x_t]+b_C) \ \ \ \ 候選值 \\ C_t &=f_t\cdot C_{t-1}+i_t\cdot \widetilde{C}_t \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ cell \ state\\ h_t &=o_t \cdot tanh(C_t) \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ 輸出值\\ \end{aligned}

2.結構對比

RNN的重複模塊中,只有一個tanh層

LSTM的重複模塊中,有四個層,多了三個門(gate)

在上面兩幅圖中,每條黑線都代表一個向量,從上一個節點輸出,輸入到下一個節點。粉色圓圈代表對每個元素的操作(比如點乘),黃色方框代表神經網絡層,兩條黑線合併代表向量拼接,一條黑線分爲兩條代表複製。

二. LSTM的核心思想

原始RNN的隱藏單元只有一個狀態,即RNN詳解中的sts_t,它對短期記憶敏感而對長期記憶不那麼敏感。而LSTM增加了一個狀態,即 CC ,用它來保存長期記憶,我們稱之爲單元狀態(cell state),下文中簡稱爲cell。
LSTM的核心就是多出來的這個cell state,下圖中的水平黑線代表cell state通過時間序列不斷向前傳送。傳送圖中只有少量的線性運算作用在cell state上,所以cell state可以存儲着信息並保持它們不怎麼變而傳送得很遠。這就是它能解決長期依賴問題的原因。

LSTM可以通過門(gate)來向cell state中添加信息或刪除信息。
門可以選擇性地讓信息通過,門的結構是用一個sigmoid層來點乘cell state:

sigmoid層輸出的值從0到1,這個值描述多少信息能通過。0表示啥也過不去,1表示啥都放過去。
LSTM一共有三個門,來幫助cell state遺忘、輸入、輸出。

三. 分四步詳細講解LSTM★

1.決定什麼信息需要從cell中丟棄掉。

通過構建一個遺忘門(forget gate):輸入當前時刻的xtx_t和上一時刻的輸出ht1h_{t-1},輸出一個和Ct1C_{t-1}同維度的向量,矩陣中每一個值都代表Ct1C_{t-1}中對應參數的去留情況,0代表徹底丟掉,1代表完全保留。
$ft=σ(Wf[ht1,xt]+bf)f_t=\sigma(W_f\cdot[h_{t-1},x_t]+b_f)


舉個例子:比如一個語言模型,根據之前的所有詞預測下一個詞。在這個問題中,cell可能已經記住了當前人物的性別,以便下次預測人稱代詞(他、她)時使用。但是當我們遇到一個新人物時,我們需要將舊人物的性別忘掉。

2.決定要往cell中存儲哪些新信息。

這一步有兩個部分:
a.通過構建一個輸入門(input gate),決定要更新哪些信息。
it=σ(Wi[ht1,xt]+bi)i_t =\sigma (W_i\cdot[h_{t-1},x_t]+b_i)

b.然後構建一個候選值向量(cell):C~t\widetilde{C}_t,之後會用輸入門點乘這個候選值向量,來選出要更新的信息。
C~t=tanh(WC[ht1,xt]+bC)\widetilde{C}_t=tanh(W_C\cdot [h_{t-1},x_t]+b_C)


在語言模型的例子中:這一步我們是想要把新人物的性別記住。

3.執行前兩步:遺忘舊的、保存新的。

這一步我們對舊cell Ct1C_{t-1}進行更新,變成新cell CtC_t
Ct=ftCt1+itC~tC_t =f_t\cdot C_{t-1}+i_t\cdot \widetilde{C}_t

Ct1C_{t-1} 點乘 ftf_t 代表我們丟棄掉要遺忘的信息。C~t\widetilde{C}_t 點乘iti_t代表我們從候選值向量中挑出要更新記住的信息。

在語言模型的例子中:這一步真正執行下面的操作:忘舊人物的性別,記住新人物的性別。

4.決定輸出什麼。

分爲兩步:
a.構建一個輸出門(output gate):決定要輸出哪些信息。
ot=σ(Wo[ht1,xt]+oi)o_t=\sigma (W_o\cdot[h_{t-1},x_t]+o_i)

b.將cell CtC_t 輸入 tanhtanh函數將所有參數值壓縮爲-1到1之間的值。然後將其點乘輸出門,輸出我們想輸出的部分。
ht=ottanh(Ct)h_t=o_t \cdot tanh(C_t)


在語言模型的例子中:比如剛看到一個人稱代詞he或they(cell狀態已經存儲),而下一個詞可能是一個動詞,那麼我們從人稱代詞(cell狀態)就可以看出下一個動詞的形式,比如(makes, make),he對應makes,they對應make。

四. LSTM的變體

上述的LSTM是最原始的LSTM,還有很多變體。

第一種變體Gers & Schmidhuber (2000)提出,這種變體添加了窺視孔連接(peephole connections)。具體操作就是每個門(gate)的輸入多加了cell state。

第二種變體是去掉輸入門(input gate)。不去分開決定遺忘什麼輸入什麼,而是一起做決定,只有要遺忘的值纔去對它們輸入更新。

第三種變體Cho, et al. (2014)提出,名爲GRU。它將遺忘門和輸入們簡化爲一個更新門,還將cell state和隱藏單元(hidden state)合併起來。結構相對LSTM更簡單,也很流行。

References

[1] Understanding LSTM Networks
[2] 零基礎入門深度學習(6) - 長短時記憶網絡(LSTM)
[3] Bengio的深度學習(花書)

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章