這周在看循環數據網絡, 發現一個博客, 裏面推導極其詳細, 藉此記錄重點.
詳細推導
強烈建議手推一遍, 雖然會花一點時間, 但便於理清思路.
長短時記憶網絡
回顧BPTT算法裏誤差項沿時間反向傳播的公式:
δTk=δTt∏i=kt−1diag[f′(neti)]W(1)
根據範數的性質, 來獲取
δTk 的模的上界:
‖δTk‖⩽⩽‖δTt‖∏i=kt−1‖diag[f′(neti)]‖‖W‖‖δTt‖(βfβW)t−k(2)(3)
可以看到, 誤差項
δ 從t時刻傳遞到k時刻, 其值上界是
βfβw 的指數函數.
βfβw 分別是對角矩陣
diag[f′(neti)] 和矩陣W模的上界. 顯然, 當t-k很大時, 會有
梯度爆炸, 當t-k很小時, 會有
梯度消失.
爲了解決RNN的梯度爆炸和梯度消失的問題, 就出現了長短時記憶網絡(Long Short Memory Network, LSTM). 原始RNN的隱藏層只有一個狀態h, 它對於短期的輸入非常敏感. 如果再增加一個狀態c, 讓它來保存長期的狀態, 那麼就可以解決原始RNN無法處理長距離依賴的問題.
新增加的狀態c, 稱爲單元狀態(cell state). 上圖按照時間維度展開:
上圖中, 在t時刻, LSTM的輸入有三個: 當前時刻網絡的輸入值xt , 上一時刻LSTM的輸出值ht−1 , 以及上一時刻的單元狀態ct−1 ; LSTM的輸出有兩個: 當前時刻的LSTM輸出ht , 當前時刻的狀態ct . 其中x,h,c 都是向量.
LSTM的關鍵在於怎樣控制長期狀態c. 在這裏, LSTM的思路是使用三個控制開關:
第一個開關, 負責控制繼續保存長期狀態c; (遺忘門)
第二個開關, 負責控制把即時狀態輸入到長期狀態c; (輸入門)
第三個開關, 負責控制是都把長期狀態c作爲當前的LSTM的輸出. (輸出門)
接下來, 具體描述一下輸出h和單元狀態c的計算方法.
長短時記憶網絡的前向計算
開關在算法中用門(gate)實現. 門實際上就是一層全連接層, 它的輸入是一個向量, 輸出是一個0~1的實數向量. 假設w是門的權重向量, b是偏置項, 門可以表示爲:
g(x)=σ(Wx+b)
門的使用, 就是
用門的輸出向量按元素乘以我們需要控制的那個向量. 當門的輸出爲0時, 任何向量與之相乘都會得到0向量, 相當於什麼都不能通過; 當輸出爲1時, 任何向量與之相乘都爲本身, 相當於什麼都可以通過. 上式中
σ 是sigmoid函數, 值域爲(0,1), 所以門的狀態是半開半閉的.
LSTM用兩個門來控制單元狀態c的內容, 一個是遺忘門(forget gate), 它決定了上一時刻的單元狀態ct−1 有多少保留到當前時刻ct ; 另一個是輸入門(input gate), 它決定了當前時刻網絡的輸入xt 有多少保存到單元狀態ct . LSTM用輸出門(output gate)來控制單元狀態ct 有多少輸出到LSTM的當前輸出值ht .
1. 遺忘門:
ft=σ(Wf⋅[ht−1,xt]+bf)(式1)
上式中,
Wf 是遺忘門的權重矩陣,
[ht−1,xt] 表示把兩個向量連接到一個更長的向量,
bf 是遺忘門的偏置項,
σ 是sigmoid函數. 如果輸入的維度是
dh , 單元狀態的維度是
dc (通常
dc=dh ), 則遺忘門的權重矩陣
Wf 維度是
dc×(dh+dx) .
事實上, 權重矩陣Wf 都是兩個矩陣拼接而成的: 一個是Wfh , 它對應着輸入項ht−1 , 其維度爲dc×dh ; 一個是Wfx , 它對應着輸入項xt , 其維度爲dc×dh . Wf 可以寫成:
[Wf][ht−1xt]=[WfhWfx][ht−1xt]=Wfhht−1+Wfxxt(4)(5)
下圖是遺忘門的計算:
2. 輸入門:
it=σ(Wi⋅[ht−1,xt]+bi)(式2)
上式中,
Wi 是輸入門的權重矩陣,
bi 是輸入門的偏置項.
下圖是輸入門的計算:
接下來, 計算用於描述當前輸入的單元狀態c̃t , 它是根據根據上一次的輸出和本次的輸入來計算的:
c̃t=tanh(Wc⋅[ht−1,xt]+bc)(式3)
下圖是
c̃t 的計算:
現在, 我們計算當前時刻的單元狀態ct . 它是由上一次的單元狀態ct−1 按元素乘以遺忘門ft , 再用當前輸入的單元狀態c̃t 按元素乘以輸入門it , 再將兩個積加和產生的:
ct=ft∘ct−1+it∘c̃t(式4)
符號
∘ 表示
按元素乘. 下圖是
ct 的計算:
這樣, 就把LSTM關於當前的記憶c̃t 和長期的記憶ct−1 組合在一起, 形成了新的單元狀態ct . 由於遺忘門的控制, 它可以保存很久之前的信息, 由於輸入門的控制, 它又可以避免當前無關緊要的內容進入記憶.
3. 輸出門
ot=σ(Wo⋅[ht−1,xt]+bo)(式5)
下圖表示輸出門的計算:
LSTM最終的輸出, 是由輸出門和單元狀態共同確定的:
ht=ot∘tanh(ct)(式6)
下圖表示LSTM最終輸出的計算:
式1到式6就是LSTM前向計算的全部公式.
長短時記憶網絡的訓練
訓練部分比前向計算部分複雜, 具體推導如下.
LSTM訓練算法框架
LSTM的訓練算法仍然是反向傳播算法, 主要是三個步驟:
- 前向計算每個神經元的輸出值, 對於LSTM來說, 即ft,it,ctot,ht 五個向量的值;
- 反向計算每個神經元的誤差項δ 值, 與RNN一樣, LSTM誤差項的反向傳播也是包括兩個方向: 一個沿時間的反向傳播, 即從當前t時刻開始, 計算每個時刻的誤差項; 一個是將誤差項向上一層傳播;
- 根據相應的誤差項, 計算每個權重的梯度.
關於公式和符號的說明
接下來的推導, 設定gate的激活函數爲sigmoid, 輸出的激活函數爲tanh函數. 他們的導數分別爲:
σ(z)σ′(z)tanh(z)tanh′(z)=y=11+e−z=y(1−y)=y=ez−e−zez+e−z=1−y2(6)(7)(8)(9)
從上式知, sigmoid函數和tanh函數的導數都是原函數的函數, 那麼計算出原函數的值, 導數便也計算出來.
LSTM需要學習的參數共有8組, 權重矩陣的兩部分在反向傳播中使用不同的公式, 分別是:
- 遺忘門的權重矩陣Wf 和偏置項bt , Wf 分開爲兩個矩陣Wfh 和Wfx
- 輸入門的權重矩陣Wi 和偏置項bi , Wi 分開爲兩個矩陣Wih 和Wxi
- 輸出門的權重矩陣Wo 和偏置項bo , Wo 分開爲兩個矩陣Woh 和Wox
- 計算單元狀態的權重矩陣Wc 和偏置項bc , Wc 分開爲兩個矩陣Wch 和Wcx
按元素乘∘ 符號. 當∘ 作用於兩個向量時, 運算如下:
a∘b=⎡⎣⎢⎢⎢⎢⎢a1a2a3...an⎤⎦⎥⎥⎥⎥⎥∘⎡⎣⎢⎢⎢⎢⎢b1b2b3...bn⎤⎦⎥⎥⎥⎥⎥=⎡⎣⎢⎢⎢⎢⎢a1b1a2b2a3b3...anbn⎤⎦⎥⎥⎥⎥⎥
當
∘ 作用於
一個向量和
一個矩陣時, 運算如下:
a∘X=⎡⎣⎢⎢⎢⎢⎢a1a2a3...an⎤⎦⎥⎥⎥⎥⎥∘⎡⎣⎢⎢⎢⎢⎢x11x21x31xn1x12x22x32xn2x13x23x33...xn3............x1nx2nx3nxnn⎤⎦⎥⎥⎥⎥⎥=⎡⎣⎢⎢⎢⎢⎢a1x11a2x21a3x31anxn1a1x12a2x22a3x32anxn2a1x13a2x23a3x33...anxn3............a1x1na2x2na3x3nanxnn⎤⎦⎥⎥⎥⎥⎥(10)(11)
當
∘ 作用於
兩個矩陣時, 兩個矩陣對應位置的元素相乘. 按元素乘可以在某些情況下簡化矩陣和向量運算.
例如, 當一個對角矩陣右乘一個矩陣時, 相當於用對角矩陣的對角線組成的向量按元素乘那個矩陣:
diag[a]X=a∘X
當一個行向量左乘一個對角矩陣時, 相當於這個行向量按元素乘那個矩陣對角組成的向量:
aTdiag[b]=a∘b
在t時刻, LSTM的輸出值爲
ht . 我們定義t時刻的誤差項
δt 爲:
δt=def∂E∂ht
這裏假設誤差項是損失函數對輸出值的導數, 而不是對加權輸出
netlt 的導數. 因爲LSTM有四個加權輸入, 分別對應
ft,it,ct,ot , 我們希望往上一層傳遞一個誤差項而不是四個, 但需要定義這四個加權輸入以及它們對應的誤差項.
netf,tneti,tnetc̃,tneto,tδf,tδi,tδc̃,tδo,t=Wf[ht−1,xt]+bf=Wfhht−1+Wfxxt+bf=Wi[ht−1,xt]+bi=Wihht−1+Wixxt+bi=Wc[ht−1,xt]+bc=Wchht−1+Wcxxt+bc=Wo[ht−1,xt]+bo=Wohht−1+Woxxt+bo=def∂E∂netf,t=def∂E∂neti,t=def∂E∂netc̃,t=def∂E∂neto,t(12)(13)(14)(15)(16)(17)(18)(19)(20)(21)(22)(23)
誤差項沿時間的反向傳遞
沿時間反向傳遞誤差項, 就是要計算出t-1時刻的誤差項δt−1 .
δTt−1=∂E∂ht−1=∂E∂ht∂ht∂ht−1=δTt∂ht∂ht−1(24)(25)(26)
其中,
∂ht∂ht−1 是一個Jacobian矩陣, 爲了求出它, 需要列出
ht 的計算公式, 即前面的
式6和
式4:
ht=ot∘tanh(ct)(式6)ct=ft∘ct−1+it∘c̃t(式4)
顯然,
ot,ft,it,c̃t 都是
ht−1 的函數, 那麼, 利用全導數公式可得:
δTt∂ht∂ht−1=δTt∂ht∂ot∂ot∂neto,t∂neto,t∂ht−1+δTt∂ht∂ct∂ct∂ft∂ft∂netf,t∂netf,t∂ht−1+δTt∂ht∂ct∂ct∂it∂it∂neti,t∂neti,t∂ht−1+δTt∂ht∂ct∂ct∂c̃t∂c̃t∂netc̃,t∂netc̃,t∂ht−1=δTo,t∂neto,t∂ht−1+δTf,t∂netf,t∂ht−1+δTi,t∂neti,t∂ht−1+δTc̃,t∂netc̃,t∂ht−1(式7)(27)(28)(29)
下面, 要把
式7中的每個偏導數都求出來, 根據
式6, 可以求出:
∂ht∂ot∂ht∂ct=diag[tanh(ct)]=diag[ot∘(1−tanh(ct)2)](30)(31)
根據
式4, 可以求出:
∂ct∂ft∂ct∂it∂ct∂c̃t=diag[ct−1]=diag[c̃t]=diag[it](32)(33)(34)
因爲:
otneto,tftnetf,titneti,tc̃tnetc̃,t=σ(neto,t)=Wohht−1+Woxxt+bo=σ(netf,t)=Wfhht−1+Wfxxt+bf=σ(neti,t)=Wihht−1+Wixxt+bi=tanh(netc̃,t)=Wchht−1+Wcxxt+bc(35)(36)(37)(38)(39)(40)(41)(42)(43)(44)(45)
可以得出:
∂ot∂neto,t∂neto,t∂ht−1∂ft∂netf,t∂netf,t∂ht−1∂it∂neti,t∂neti,t∂ht−1∂c̃t∂netc̃,t∂netc̃,t∂ht−1=diag[ot∘(1−ot)]=Woh=diag[ft∘(1−ft)]=Wfh=diag[it∘(1−it)]=Wih=diag[1−c̃2t]=Wch(46)(47)(48)(49)(50)(51)(52)(53)
將上述偏導數導入到
式7, 可以得到:
δt−1=δTo,t∂neto,t∂ht−1+δTf,t∂netf,t∂ht−1+δTi,t∂neti,t∂ht−1+δTc̃,t∂netc̃,t∂ht−1=δTo,tWoh+δTf,tWfh+δTi,tWih+δTc̃,tWch(式8)(54)(55)
根據
δo,t,δf,t,δi,t,δc̃,t 的定義, 可知:
δTo,tδTf,tδTi,tδTc̃,t=δTt∘tanh(ct)∘ot∘(1−ot)(式9)=δTt∘ot∘(1−tanh(ct)2)∘ct−1∘ft∘(1−ft)(式10)=δTt∘ot∘(1−tanh(ct)2)∘c̃t∘it∘(1−it)(式11)=δTt∘ot∘(1−tanh(ct)2)∘it∘(1−c̃2)(式12)(56)(57)(58)(59)
式8到
式12就是將誤差沿時間反向傳播一個時刻的公式. 有了它, 便可以寫出將誤差項傳遞到任意k時刻的公式:
δTk=∏j=kt−1δTo,jWoh+δTf,jWfh+δTi,jWih+δTc̃,jWch(式13)
將誤差項傳遞到上一層
假設當前是第l 層, 定義l−1 層的誤差項是誤差函數對l−1 層加權輸入的導數, 即:
δl−1t=def∂Enetl−1t
本次LSTM的輸入
xt 由下面的公式計算:
xlt=fl−1(netl−1t)
上式中,
fl−1 表示第
l−1 的
激活函數.
因爲netlf,t,netli,t,netlc̃,t,netlo,t 都是xt 的函數, xt 又是netl−1t 的函數, 因此, 要求出E 對netl−1t 的導數, 就需要使用全導數公式:
∂E∂netl−1t=∂E∂netlf,t∂netlf,t∂xlt∂xlt∂netl−1t+∂E∂netli,t∂netli,t∂xlt∂xlt∂netl−1t+∂E∂netlc̃,t∂netlc̃,t∂xlt∂xlt∂netl−1t+∂E∂netlo,t∂netlo,t∂xlt∂xlt∂netl−1t=δTf,tWfx∘f′(netl−1t)+δTi,tWix∘f′(netl−1t)+δTc̃,tWcx∘f′(netl−1t)+δTo,tWox∘f′(netl−1t)=(δTf,tWfx+δTi,tWix+δTc̃,tWcx+δTo,tWox)∘f′(netl−1t)(式14)(60)(61)(62)(63)
式14就是將誤差傳遞到上一層的公式.
權重梯度的計算
對於Wfh,Wih,Wch,Woh 的權重梯度, 我們知道它的梯度是各個時刻梯度之和. 我們首先求出它們在t時刻的梯度, 然後再求出他們最終的梯度.
我們已經求得了誤差項δo,t,δf,t,δi,t,δc̃,t , 很容易求出t時刻的Woh,Wfh,Wih,Wch :
∂E∂Woh,t∂E∂Wfh,t∂E∂Wih,t∂E∂Wch,t=∂E∂neto,t∂neto,t∂Woh,t=δo,thTt−1=∂E∂netf,t∂netf,t∂Wfh,t=δf,thTt−1=∂E∂neti,t∂neti,t∂Wih,t=δi,thTt−1=∂E∂netc̃,t∂netc̃,t∂Wch,t=δc̃,thTt−1(64)(65)(66)(67)(68)(69)(70)(71)(72)(73)(74)
將各個時刻的梯度加在一起, 就能得到最終的梯度:
∂E∂Woh∂E∂Wfh∂E∂Wih∂E∂Wch=∑j=1tδo,jhTj−1=∑j=1tδf,jhTj−1=∑j=1tδi,jhTj−1=∑j=1tδc̃,jhTj−1(75)(76)(77)(78)
對於偏置項
bf,bi,bc,bo 的梯度, 先求出各個時刻的偏置項梯度:
∂E∂bo,t∂E∂bf,t∂E∂bi,t∂E∂bc,t=∂E∂neto,t∂neto,t∂bo,t=δo,t=∂E∂netf,t∂netf,t∂bf,t=δf,t=∂E∂neti,t∂neti,t∂bi,t=δi,t=∂E∂netc̃,t∂netc̃,t∂bc,t=δc̃,t(79)(80)(81)(82)(83)(84)(85)(86)(87)(88)(89)
將各個時刻的偏置項梯度加在一起:
∂E∂bo∂E∂bi∂E∂bf∂E∂bc=∑j=1tδo,j=∑j=1tδi,j=∑j=1tδf,j=∑j=1tδc̃,j(90)(91)(92)(93)
對於
Wfx,Wix,Wcx,Wox 的權重梯度, 只需要根據相應的誤差項直接計算即可:
∂E∂Wox∂E∂Wfx∂E∂Wix∂E∂Wcx=∂E∂neto,t∂neto,t∂Wox=δo,txTt=∂E∂netf,t∂netf,t∂Wfx=δf,txTt=∂E∂neti,t∂neti,t∂Wix=δi,txTt=∂E∂netc̃,t∂netc̃,t∂Wcx=δc̃,txTt(94)(95)(96)(97)(98)(99)(100)(101)(102)(103)(104)
以上就是LSTM的訓練算法的全部公式
GRU
上面所述是一種普通的LSTM, 事實上LSTM存在很多變體, GRU就是其中一種最成功的變體. 它對LSTM做了很多簡化, 同時保持和LSTM相同的效果.
GRU對LSTM做了兩大改動:
- 將輸入門, 遺忘門, 輸出門變爲兩個門: 更新門(Update Gate) zt 和充值門(Reset Gate) rt .
- 將單元狀態與輸出合併爲一個狀態: h
GRU的前向計算公式爲:
ztrth̃th=σ(Wz⋅[ht−1,xt])=σ(Wr⋅[ht−1,xt])=tanh(W⋅[rt∘ht−1,xt])=(1−zt)∘ht−1+zt∘h̃t(105)(106)(107)(108)
下圖是GRU的示意圖: