【機器學習】從RNN到Attention上篇 循環神經網絡RNN,門控循環神經網絡LSTM

打算寫一個從RNN到Attention的系列文章,今天先介紹一下循環神經網絡RNN和門控循環神經網絡LSTM,很多內容爲筆者自己的理解,難免有疏漏之處,歡迎大家探討。
文章有一些修改,因爲是在本人的知乎專欄裏劉改的,不想來回修改,大家可以去【從RNN到Attention】上篇 循環神經網絡RNN,門控循環神經網絡LSTM

一.爲什麼RNN比DNN更適合時間序列問題

DNN求解時序問題

對於一個時間序列問題,以單詞預測爲例,已知x1,x2,x3,,xtx_1,x_2,x_3,……,x_t,求解t時刻的單詞xt+1x_{t+1},那麼從概率的角度,該問題可以建模爲求解argmaxθP(xt+1x1,x2,....xtθ)argmax_{\theta}P(x_{t+1}|x_{1},x_2,....x_t,\theta),其中θ\theta爲模型參數。如果我們用DNN求解該問題,則模型輸入輸出可以分別表示爲
X=[x1,x2,x3,,xt1,xt]X=[x_1,x_2,x_3,……,x_{t-1},x_t]
Y=xt+1Y=x_{t+1}

似乎沒有什麼問題,但是假設一個單詞的維度爲dd,則XX的維度爲dtd*t,僅考慮從輸入到第一層隱藏層,且隱藏層的維度爲mm,那麼其中的參數總量爲dtmd*t*m,如下圖所示,隨着t的增長,參數量的增長是非常恐怖的,而且採用這種建模方式,x1,x2,x3,xtx_1,x_2,x_3,……x_t對於模型來說是等價的,丟失了他們的時序關係,因此DNN處理時序問題存在

  • 1.參數量過大
  • 2.丟失了時序關係
    DNN參數示意圖,自己畫的,有點醜

RNN求解時序問題

RNN的結構如圖表示
RNN網絡結構圖
其中xix_{i}爲輸入,對應單詞預測問題即爲單詞的向量表示,hih_{i}爲隱含層(hidden layer),是循環神經網絡中特有的網絡結構,其中
Ht=ϕ(XtWxh+Ht1Whh+bh).\boldsymbol{H}_t = \phi(\boldsymbol{X}_t \boldsymbol{W}{xh} + \boldsymbol{H}_{t-1} \boldsymbol{W}_{hh} + \boldsymbol{b}_h).
我們從上述式子可以看出:

  • 隱含狀態HtH_t與t時刻輸入xtx_t和上一時刻的隱含狀態Ht1H_{t-1}有關,而Ht1H_{t-1}也同樣與t-1時刻輸入xt1x_{t-1}和上上一時刻的隱含狀態Ht2H_{t-2}有關,以此類推,HtH_t可以作爲t時刻之前的輸入和隱藏狀態的信息儲藏,而由於更近的時刻信息儲藏的更加完整,從而既保留了之前的輸入信息,同時還保證了他們時序關係
  • XXHt1H_{t-1}分別通過兩個矩陣乘法與HtH_t相關聯。
  • 如果去掉Ht1Whh\boldsymbol{H}_{t-1} \boldsymbol{W}_{hh},則上式就是一個全連接。
  • 事實上,我們令Xt=[Xt,Ht1],W=[Wxh,Whh]X^{'}_t=[X_t,H_{t-1}],W^{'}=[W_{xh},W_{hh}],則上式可以改寫爲Ht=ϕ(XtW+bh)H_t= \phi(X^{'}_tW^{'}+b_h)我們可以通過全連接來實現RNN
  • 我們來看一下參數量,循環神經網絡中的隱含狀態與隱藏層作用類似,因此我們可以比較兩者的參數量大小,我們假定隱藏層的維度也爲m,首先忽略bhb_h因爲都是m維,則WxhW_{xh}的維度爲x的維度d*隱藏層的維度m,即dmd*mWhhW_{hh}的維度爲mmm*m,因此總的維度爲(d+m)m(d+m)*m,顯然遠遠小於DNN的dtmd*t*m且與tt的長度無關!理論上,我們可以將輸入的長度拉倒無限長。
  • 我們再來思考一下爲什麼循環神經網絡的參數量與tt的長度無關呢?因爲對於長度爲tt的輸入,他們共用了同一個WxhW_{xh}WhhW_{hh},大大減少了參數量。
  • 我們怎麼從隱藏層hth_t得到yty_t的呢?其實隱藏層hth_t的作用和DNN中的隱藏層作用類似,我們可以有很多處理方式,比如直接通過softmax求出yty_t的概率分佈,也可以作爲一個全連接層的輸入,再經過別的操作得到yty_t

二、門控循環神經網絡LSTM

從上面的介紹我們可以看出RNN的關鍵在於HtH_t保存之前的信息應用到當前的任務之上,但是HtH_t真的可以做到嗎?很難!當時間步距離較大時,循環神經網絡在反向傳播的過程中的梯度較容易出現衰減或爆炸(詳見通過時間反向傳播),LSTM(Long Short Term Memory)可以避免上述的長期依賴問題,由於GRU和LSTM類似,基本可以視爲LSTM的簡化版,在這裏就不做贅述。
LSTM的網絡結構圖如下所示:
圖片來自李沐老師《動手深度學習》
如果有小夥伴看過這張圖,不知道初次看的時候內心是什麼感受,反正我當時是一臉懵逼(臥槽,這什麼玩意兒?)仔細研究過後,我發現其實LSTM的整個網絡結構可以簡述爲“三門兩細胞”,我們依照這個主線來理解應該會更輕鬆一些,首先來看“三門”:記憶門,遺忘門和輸出門。
It=σ(XtWxi+Ht1Whi+bi) \begin{aligned} \boldsymbol{I}_t &= \sigma(\boldsymbol{X}_t \boldsymbol{W}{xi} + \boldsymbol{H}_{t-1} \boldsymbol{W}_{hi} + \boldsymbol{b}i) \end{aligned}
 Ft=σ(XtWxf+Ht1Whf+bf), \begin{aligned}\ \boldsymbol{F}_t &= \sigma(\boldsymbol{X}_t \boldsymbol{W}{xf} + \boldsymbol{H}_{t-1} \boldsymbol{W}{hf} + \boldsymbol{b}f),\end{aligned}
 Ot=σ(XtWxo+Ht1Who+bo), \begin{aligned}\ \boldsymbol{O}_t &= \sigma(\boldsymbol{X}_t \boldsymbol{W}{xo} + \boldsymbol{H}_{t-1} \boldsymbol{W}{ho} + \boldsymbol{b}_o), \end{aligned}
這三個門在之後的計算中分別承載了不同的物理意義,計算上和之前RNN中隱藏層的計算差不多,也就是矩陣運算+激活函數,同樣用到了前一時刻的隱含變量Ht1H_{t-1}和當前時刻的輸入XtX_t,事實上他們也都可以通過一個全連接表示。
“兩細胞”包括候選記憶細胞C~t\tilde{\boldsymbol{C}}_t和記憶細胞Ct\boldsymbol{C}_t
候選記憶細胞C~t\tilde{\boldsymbol{C}}t的表達式爲
C~t=tanh(XtWxc+Ht1Whc+bc)\tilde{\boldsymbol{C}}_t = \text{tanh}(\boldsymbol{X}_t \boldsymbol{W}{xc} + \boldsymbol{H}_{t-1} \boldsymbol{W}_{hc} + \boldsymbol{b}_c)
它的計算與上面介紹的3個門也類似,但使用了值域在 [−1,1] 的tanh函數作爲激活函數。候選記憶細胞C~t\tilde{\boldsymbol{C}}t的作用是作爲記憶細胞Ct\boldsymbol{C}_t的輸入
記憶細胞Ct\boldsymbol{C}_t的計算公式爲:
Ct=FtCt1+ItC~t\boldsymbol{C}_t = \boldsymbol{F}_t \odot \boldsymbol{C}_{t-1} + \boldsymbol{I}_t \odot \tilde{\boldsymbol{C}}_t
其中\odot爲點乘,此時我們發現在記憶細胞Ct\boldsymbol{C}_t的計算公式中,用到了遺忘門Ft\boldsymbol{F}_t,並且與前一時刻的記憶細胞Ct1\boldsymbol{C}_{t-1}做點乘,表達的物理含義是我們希望對之前記憶的遺忘程度,當遺忘門某維度近似1,則該維度上一時刻的記憶被傳遞到當前記憶細胞,反之則被遺忘
同樣的,對於輸入門It\boldsymbol{I}_t,並且與當前時刻的候選記憶細胞C~t\tilde{\boldsymbol{C}}_t做點乘,表達對於當前時刻的候選記憶細胞的接收程度,當輸入門某維度近似1,則當前時刻的候選記憶細胞的該維度信息被接收到當前記憶細胞,反之被忽略
我們再來做個比較,其實它和RNN的公式Ht=ϕ(XtWxh+Ht1Whh+bh).\boldsymbol{H}_t = \phi(\boldsymbol{X}_t \boldsymbol{W}{xh} + \boldsymbol{H}_{t-1} \boldsymbol{W}_{hh} + \boldsymbol{b}_h).很相似,Ft\boldsymbol{F}_t類似於Wt1\boldsymbol{W}_{t-1},都是對於歷史數據的處理,輸入門It\boldsymbol{I}_tWhh\boldsymbol{W}_{hh}類似,都是表達對於輸入的處理,不同的是Ft\boldsymbol{F}_tIt\boldsymbol{I}_t是做點乘,另外二者爲矩陣乘法。
最後隱藏層的輸出爲
Ht=Ottanh(Ct).\boldsymbol{H}_t = \boldsymbol{O}_t \odot \text{tanh}(\boldsymbol{C}_t).
同樣是點乘,Ot\boldsymbol{O}_t是物理含義是對於輸出的篩選,當輸出門某維度近似1時,記憶細胞將該維度的信息傳遞到隱藏層供輸出層使用;當輸出門近似0時,則該維度的信息無法傳遞到隱藏層
我們最後再總結一下LSTM的整個設計思想

  • 當前輸入XtX_t和前一時刻的隱含狀態Ht1H_{t-1}生成輸入門ItI_t、輸出門OtO_t和遺忘門FtF_t,以及候選記憶細胞C~t\tilde{\boldsymbol{C}}_t
  • 候選記憶細胞C~t\tilde{\boldsymbol{C}}_t和輸入門ItI_t控制當前時刻對於記憶細胞Ct\boldsymbol{C}_t輸入,遺忘門FtF_t和前一時刻的記憶細胞C~t1\tilde{\boldsymbol{C}}_{t-1}控制記憶細胞歷史時刻的輸入,注意這裏是點乘
  • 記憶細胞Ct\boldsymbol{C}_t和輸出門OtO_t控制隱藏層,注意這裏也是點乘
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章