循環神經網絡2--LSTM

這周在看循環數據網絡, 發現一個博客, 裏面推導極其詳細, 藉此記錄重點.

詳細推導

強烈建議手推一遍, 雖然會花一點時間, 但便於理清思路.

長短時記憶網絡

回顧BPTT算法裏誤差項沿時間反向傳播的公式:

(1)δkT=δtTi=kt1diag[f(neti)]W

根據範數的性質, 來獲取δkT 的模的上界:
(2)δkTδtTi=kt1diag[f(neti)]W(3)δtT(βfβW)tk

可以看到, 誤差項δ 從t時刻傳遞到k時刻, 其值上界是βfβw 的指數函數. βfβw 分別是對角矩陣diag[f(neti)] 和矩陣W模的上界. 顯然, 當t-k很大時, 會有梯度爆炸, 當t-k很小時, 會有梯度消失.

爲了解決RNN的梯度爆炸和梯度消失的問題, 就出現了長短時記憶網絡(Long Short Memory Network, LSTM). 原始RNN的隱藏層只有一個狀態h, 它對於短期的輸入非常敏感. 如果再增加一個狀態c, 讓它來保存長期的狀態, 那麼就可以解決原始RNN無法處理長距離依賴的問題.

img

新增加的狀態c, 稱爲單元狀態(cell state). 上圖按照時間維度展開:

img

上圖中, 在t時刻, LSTM的輸入有三個: 當前時刻網絡的輸入值xt , 上一時刻LSTM的輸出值ht1 , 以及上一時刻的單元狀態ct1 ; LSTM的輸出有兩個: 當前時刻的LSTM輸出ht , 當前時刻的狀態ct . 其中x,h,c 都是向量.

LSTM的關鍵在於怎樣控制長期狀態c. 在這裏, LSTM的思路是使用三個控制開關:

第一個開關, 負責控制繼續保存長期狀態c; (遺忘門)

第二個開關, 負責控制把即時狀態輸入到長期狀態c; (輸入門)

第三個開關, 負責控制是都把長期狀態c作爲當前的LSTM的輸出. (輸出門)

img

接下來, 具體描述一下輸出h和單元狀態c的計算方法.

長短時記憶網絡的前向計算

開關在算法中用門(gate)實現. 門實際上就是一層全連接層, 它的輸入是一個向量, 輸出是一個0~1的實數向量. 假設w是門的權重向量, b是偏置項, 門可以表示爲:

g(x)=σ(Wx+b)

門的使用, 就是用門的輸出向量按元素乘以我們需要控制的那個向量. 當門的輸出爲0時, 任何向量與之相乘都會得到0向量, 相當於什麼都不能通過; 當輸出爲1時, 任何向量與之相乘都爲本身, 相當於什麼都可以通過. 上式中σ 是sigmoid函數, 值域爲(0,1), 所以門的狀態是半開半閉的.

LSTM用兩個門來控制單元狀態c的內容, 一個是遺忘門(forget gate), 它決定了上一時刻的單元狀態ct1 有多少保留到當前時刻ct ; 另一個是輸入門(input gate), 它決定了當前時刻網絡的輸入xt 有多少保存到單元狀態ct . LSTM用輸出門(output gate)來控制單元狀態ct 有多少輸出到LSTM的當前輸出值ht .

1. 遺忘門:

ft=σ(Wf[ht1,xt]+bf)(1)

上式中, Wf 是遺忘門的權重矩陣, [ht1,xt] 表示把兩個向量連接到一個更長的向量, bf 是遺忘門的偏置項, σ 是sigmoid函數. 如果輸入的維度是dh , 單元狀態的維度是dc (通常dc=dh ), 則遺忘門的權重矩陣Wf 維度是dc×(dh+dx) .

事實上, 權重矩陣Wf 都是兩個矩陣拼接而成的: 一個是Wfh , 它對應着輸入項ht1 , 其維度爲dc×dh ; 一個是Wfx , 它對應着輸入項xt , 其維度爲dc×dh . Wf 可以寫成:

(4)[Wf][ht1xt]=[WfhWfx][ht1xt](5)=Wfhht1+Wfxxt

下圖是遺忘門的計算:

img

2. 輸入門:

it=σ(Wi[ht1,xt]+bi)(2)

上式中, Wi 是輸入門的權重矩陣, bi 是輸入門的偏置項.

下圖是輸入門的計算:

img

接下來, 計算用於描述當前輸入的單元狀態c~t , 它是根據根據上一次的輸出和本次的輸入來計算的:

c~t=tanh(Wc[ht1,xt]+bc)(3)

下圖是c~t 的計算:

img

現在, 我們計算當前時刻的單元狀態ct . 它是由上一次的單元狀態ct1 按元素乘以遺忘門ft , 再用當前輸入的單元狀態c~t 按元素乘以輸入門it , 再將兩個積加和產生的:

ct=ftct1+itc~t(4)

符號 表示按元素乘. 下圖是ct 的計算:

img

這樣, 就把LSTM關於當前的記憶c~t 和長期的記憶ct1 組合在一起, 形成了新的單元狀態ct . 由於遺忘門的控制, 它可以保存很久之前的信息, 由於輸入門的控制, 它又可以避免當前無關緊要的內容進入記憶.

3. 輸出門

ot=σ(Wo[ht1,xt]+bo)(5)

下圖表示輸出門的計算:

img

LSTM最終的輸出, 是由輸出門和單元狀態共同確定的:

ht=ottanh(ct)(6)

下圖表示LSTM最終輸出的計算:

img

式1式6就是LSTM前向計算的全部公式.

長短時記憶網絡的訓練

訓練部分比前向計算部分複雜, 具體推導如下.

LSTM訓練算法框架

LSTM的訓練算法仍然是反向傳播算法, 主要是三個步驟:

  1. 前向計算每個神經元的輸出值, 對於LSTM來說, 即ft,it,ctot,ht 五個向量的值;
  2. 反向計算每個神經元的誤差項δ 值, 與RNN一樣, LSTM誤差項的反向傳播也是包括兩個方向: 一個沿時間的反向傳播, 即從當前t時刻開始, 計算每個時刻的誤差項; 一個是將誤差項向上一層傳播;
  3. 根據相應的誤差項, 計算每個權重的梯度.

關於公式和符號的說明

接下來的推導, 設定gate的激活函數爲sigmoid, 輸出的激活函數爲tanh函數. 他們的導數分別爲:

(6)σ(z)=y=11+ez(7)σ(z)=y(1y)(8)tanh(z)=y=ezezez+ez(9)tanh(z)=1y2

從上式知, sigmoid函數和tanh函數的導數都是原函數的函數, 那麼計算出原函數的值, 導數便也計算出來.

LSTM需要學習的參數共有8組, 權重矩陣的兩部分在反向傳播中使用不同的公式, 分別是:

  1. 遺忘門的權重矩陣Wf 和偏置項bt , Wf 分開爲兩個矩陣WfhWfx
  2. 輸入門的權重矩陣Wi 和偏置項bi , Wi 分開爲兩個矩陣WihWxi
  3. 輸出門的權重矩陣Wo 和偏置項bo , Wo 分開爲兩個矩陣WohWox
  4. 計算單元狀態的權重矩陣Wc 和偏置項bc , Wc 分開爲兩個矩陣WchWcx

按元素乘 符號. 當 作用於兩個向量時, 運算如下:

ab=[a1a2a3...an][b1b2b3...bn]=[a1b1a2b2a3b3...anbn]

作用於一個向量一個矩陣時, 運算如下:
(10)aX=[a1a2a3...an][x11x12x13...x1nx21x22x23...x2nx31x32x33...x3n...xn1xn2xn3...xnn](11)=[a1x11a1x12a1x13...a1x1na2x21a2x22a2x23...a2x2na3x31a3x32a3x33...a3x3n...anxn1anxn2anxn3...anxnn]

作用於兩個矩陣時, 兩個矩陣對應位置的元素相乘. 按元素乘可以在某些情況下簡化矩陣和向量運算.

例如, 當一個對角矩陣右乘一個矩陣時, 相當於用對角矩陣的對角線組成的向量按元素乘那個矩陣:

diag[a]X=aX

當一個行向量左乘一個對角矩陣時, 相當於這個行向量按元素乘那個矩陣對角組成的向量:
aTdiag[b]=ab

在t時刻, LSTM的輸出值爲ht . 我們定義t時刻的誤差項δt 爲:
δt=defEht

這裏假設誤差項是損失函數對輸出值的導數, 而不是對加權輸出nettl 的導數. 因爲LSTM有四個加權輸入, 分別對應ft,it,ct,ot , 我們希望往上一層傳遞一個誤差項而不是四個, 但需要定義這四個加權輸入以及它們對應的誤差項.
(12)netf,t=Wf[ht1,xt]+bf(13)=Wfhht1+Wfxxt+bf(14)neti,t=Wi[ht1,xt]+bi(15)=Wihht1+Wixxt+bi(16)netc~,t=Wc[ht1,xt]+bc(17)=Wchht1+Wcxxt+bc(18)neto,t=Wo[ht1,xt]+bo(19)=Wohht1+Woxxt+bo(20)δf,t=defEnetf,t(21)δi,t=defEneti,t(22)δc~,t=defEnetc~,t(23)δo,t=defEneto,t

誤差項沿時間的反向傳遞

沿時間反向傳遞誤差項, 就是要計算出t-1時刻的誤差項δt1 .

(24)δt1T=Eht1(25)=Ehththt1(26)=δtThtht1

其中, htht1 是一個Jacobian矩陣, 爲了求出它, 需要列出ht 的計算公式, 即前面的式6式4:
ht=ottanh(ct)(6)ct=ftct1+itc~t(4)

顯然, ot,ft,it,c~t 都是ht1 的函數, 那麼, 利用全導數公式可得:
(27)δtThtht1=δtThtototneto,tneto,tht1+δtThtctctftftnetf,tnetf,tht1(28)+δtThtctctititneti,tneti,tht1+δtThtctctc~tc~tnetc~,tnetc~,tht1(29)=δo,tTneto,tht1+δf,tTnetf,tht1+δi,tTneti,tht1+δc~,tTnetc~,tht1(7)

下面, 要把式7中的每個偏導數都求出來, 根據式6, 可以求出:
(30)htot=diag[tanh(ct)](31)htct=diag[ot(1tanh(ct)2)]

根據式4, 可以求出:
(32)ctft=diag[ct1](33)ctit=diag[c~t](34)ctc~t=diag[it]

因爲:
(35)ot=σ(neto,t)(36)neto,t=Wohht1+Woxxt+bo(37)(38)ft=σ(netf,t)(39)netf,t=Wfhht1+Wfxxt+bf(40)(41)it=σ(neti,t)(42)neti,t=Wihht1+Wixxt+bi(43)(44)c~t=tanh(netc~,t)(45)netc~,t=Wchht1+Wcxxt+bc

可以得出:
(46)otneto,t=diag[ot(1ot)](47)neto,tht1=Woh(48)ftnetf,t=diag[ft(1ft)](49)netf,tht1=Wfh(50)itneti,t=diag[it(1it)](51)neti,tht1=Wih(52)c~tnetc~,t=diag[1c~t2](53)netc~,tht1=Wch

將上述偏導數導入到式7, 可以得到:
(54)δt1=δo,tTneto,tht1+δf,tTnetf,tht1+δi,tTneti,tht1+δc~,tTnetc~,tht1(55)=δo,tTWoh+δf,tTWfh+δi,tTWih+δc~,tTWch(8)

根據δo,t,δf,t,δi,t,δc~,t 的定義, 可知:
(56)δo,tT=δtTtanh(ct)ot(1ot)(9)(57)δf,tT=δtTot(1tanh(ct)2)ct1ft(1ft)(10)(58)δi,tT=δtTot(1tanh(ct)2)c~tit(1it)(11)(59)δc~,tT=δtTot(1tanh(ct)2)it(1c~2)(12)

式8式12就是將誤差沿時間反向傳播一個時刻的公式. 有了它, 便可以寫出將誤差項傳遞到任意k時刻的公式:
δkT=j=kt1δo,jTWoh+δf,jTWfh+δi,jTWih+δc~,jTWch(13)

將誤差項傳遞到上一層

假設當前是第l 層, 定義l1 層的誤差項是誤差函數對l1加權輸入的導數, 即:

δtl1=defEnettl1

本次LSTM的輸入xt 由下面的公式計算:
xtl=fl1(nettl1)

上式中, fl1 表示第l1激活函數.

因爲netf,tl,neti,tl,netc~,tl,neto,tl 都是xt 的函數, xt 又是nettl1 的函數, 因此, 要求出Enettl1 的導數, 就需要使用全導數公式:

(60)Enettl1=Enetf,tlnetf,tlxtlxtlnettl1+Eneti,tlneti,tlxtlxtlnettl1(61)+Enetc~,tlnetc~,tlxtlxtlnettl1+Eneto,tlneto,tlxtlxtlnettl1(62)=δf,tTWfxf(nettl1)+δi,tTWixf(nettl1)+δc~,tTWcxf(nettl1)+δo,tTWoxf(nettl1)(63)=(δf,tTWfx+δi,tTWix+δc~,tTWcx+δo,tTWox)f(nettl1)(14)

式14就是將誤差傳遞到上一層的公式.

權重梯度的計算

對於Wfh,Wih,Wch,Woh 的權重梯度, 我們知道它的梯度是各個時刻梯度之和. 我們首先求出它們在t時刻的梯度, 然後再求出他們最終的梯度.

我們已經求得了誤差項δo,t,δf,t,δi,t,δc~,t , 很容易求出t時刻的Woh,Wfh,Wih,Wch :

(64)EWoh,t=Eneto,tneto,tWoh,t(65)=δo,tht1T(66)(67)EWfh,t=Enetf,tnetf,tWfh,t(68)=δf,tht1T(69)(70)EWih,t=Eneti,tneti,tWih,t(71)=δi,tht1T(72)(73)EWch,t=Enetc~,tnetc~,tWch,t(74)=δc~,tht1T

將各個時刻的梯度加在一起, 就能得到最終的梯度:

(75)EWoh=j=1tδo,jhj1T(76)EWfh=j=1tδf,jhj1T(77)EWih=j=1tδi,jhj1T(78)EWch=j=1tδc~,jhj1T

對於偏置項bf,bi,bc,bo 的梯度, 先求出各個時刻的偏置項梯度:
(79)Ebo,t=Eneto,tneto,tbo,t(80)=δo,t(81)(82)Ebf,t=Enetf,tnetf,tbf,t(83)=δf,t(84)(85)Ebi,t=Eneti,tneti,tbi,t(86)=δi,t(87)(88)Ebc,t=Enetc~,tnetc~,tbc,t(89)=δc~,t

將各個時刻的偏置項梯度加在一起:
(90)Ebo=j=1tδo,j(91)Ebi=j=1tδi,j(92)Ebf=j=1tδf,j(93)Ebc=j=1tδc~,j

對於Wfx,Wix,Wcx,Wox 的權重梯度, 只需要根據相應的誤差項直接計算即可:
(94)EWox=Eneto,tneto,tWox(95)=δo,txtT(96)(97)EWfx=Enetf,tnetf,tWfx(98)=δf,txtT(99)(100)EWix=Eneti,tneti,tWix(101)=δi,txtT(102)(103)EWcx=Enetc~,tnetc~,tWcx(104)=δc~,txtT

以上就是LSTM的訓練算法的全部公式

GRU

上面所述是一種普通的LSTM, 事實上LSTM存在很多變體, GRU就是其中一種最成功的變體. 它對LSTM做了很多簡化, 同時保持和LSTM相同的效果.

GRU對LSTM做了兩大改動:

  1. 將輸入門, 遺忘門, 輸出門變爲兩個門: 更新門(Update Gate) zt 和充值門(Reset Gate) rt .
  2. 將單元狀態與輸出合併爲一個狀態: h

GRU的前向計算公式爲:

(105)zt=σ(Wz[ht1,xt])(106)rt=σ(Wr[ht1,xt])(107)h~t=tanh(W[rtht1,xt])(108)h=(1zt)ht1+zth~t

下圖是GRU的示意圖:

img

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章