RNN-bptt簡單推導

摘要:


在前面的文章裏面,RNN訓練與BP算法,我們提到了RNN的訓練算法。但是回頭看的時候在時間的維度上沒有做處理,所以整個推導可能存在一點問題。

那麼,在這篇文章裏面,我們將介紹bptt(Back Propagation Through Time)算法如在訓練RNN。

關於bptt


這裏首先解釋一下所謂的bptt,bptt的思路其實很簡單,就是把整個RNN按時間的維度展開成一個“多層的神經網絡”。具體來說比如下圖:
這裏寫圖片描述

既然RNN已經按時間的維度展開成一個看起來像多層的神經網絡,這個時候用普通的bp算法就可以同樣的計算,只不過這裏比較複雜的是權重共享。比如上圖中每一根線就是一個權重,而我們可以看到在RNN由於權重是共享的,所以三條紅線的權重是一樣的,這在運用鏈式法則的時候稍微比較複雜。

正文:


首先,和以往一樣,我們先做一些定義。
hti=f(netthi)

netthi=m(vimxtm)+s(uisht1s)

nettyk=mwkmhtm
最後一層經過softmax的轉化
otk=enettykkenettyk
在這裏我們使用交叉熵作爲Loss Function
Et=kztklnotk

我們的任務同樣也是求EwkmEvimEuim
注意,這裏的E 沒有時間的下標。因爲在RNN裏,這些梯度分別爲各個時刻的梯度之和。
即:
Ewkm=stept=0Etwkm
Evim=stept=0Etvim
Euim=stept=0Etuim

所以下面我們推導的是EtwkmEtvimEtuim

我們先推導Etwkm
Etwkm=kEtotkotknettyknettykwkm=(otkztk)htm 。(這一部分的推導在前面的文章已經討論過了)。
在這裏,記誤差信號:
δ(output,t)k=Etnettyk=kEtotkotknettyk=(otkztk) (後面會用到)

對於EtvimEtuim 其實是差不多的,所以這裏詳細介紹其中一個。這兩個導數也是RNN裏面最複雜的。

推導:Etvim

Etvim=tt=0Etnetthinetthivim
對於這個式子第一次看可能有點懵逼,這裏稍微解釋一下:
從式:hti=f(m(vimxtm)+s(uisht1s)) 中我們可以看到,vim 影響的是所有時刻的netthi,t=0,1,2,....step 。所以當Etvim 求偏導的時候,由於鏈式法則需要考慮到所有時刻的netthi

下面分成兩部分來求Etnetthinetthivim.
第一部分:Etnetthi
這裏我們記δ(t,t)i=Etnetthi (誤差信號,和前面文章一樣)。



(由於帶着符號去求這兩個導數會讓人看起來非常懵逼,所以下面指定具體的值,後面抽象給出通式)
假設共3個時刻,即t=0,1,2。
對於t=2t=2 時:
E2 表示第2個時刻(也是最後一個時刻)的誤差)
net2hi 表示第2個時刻隱藏層第i個神經元的淨輸入)
具體來說:E2net2hi=E2h2ih2inet2hi

對於E2h2i=kE2net2yknet2ykh2i
由於δ(output,t)k=Etnettyk
所以,我們有:
E2h2i=kE2net2yknet2ykh2i=kδ(output,2)knet2ykh2i=kδ(output,2)kwki
綜上:
δ(2,2)i=E2net2hi=E2h2ih2inet2hi=(kδ(output,2)kwki)f(net2hi)

對於t=1t=2 時:
E2 表示第2個時刻的誤差)
net1hi 表示第1個時刻隱藏層第i個神經元的淨輸入)
具體來說:E2net1hi=E2h1ih1inet1hi
那麼E2h1i=kE2net1yknet1ykh1i+jE2net2hjnet2hjh1i 。請對比這個式子和上面t=2t=2 時的區別,區別在於多了一項jE2net2hjnet2hjh1i 。這個原因我們已經在RNN與bp算法中討論過,這裏簡單的說就是由於t=1 時刻有t=2 時刻反向傳播回來的誤差,所以要考慮上這一項,但是對於t=2 已經是最後一個時刻了,沒有反向傳播回來的誤差。

對於第一項kE2net1yknet1ykh1i 其實是0。下面簡單分析下原因:
上式進一步可以化爲:k(kE2o1ko1knet1yk)net1ykh1iE2 與第1個時刻輸出o1k 無關。所以爲0。

對於第二項jE2net2hjnet2hjh1i ,我們帶入δ(t,t)i=Etnetthi 有:
jE2net2hjnet2hjh1i=jδ(2,2)jnet2hjh1i
同時明顯有net2hjh1i=uji
即:E2h1i=jδ(2,2)juji

綜上:
δ(1,2)i=E2net1hi=E2h1ih1inet1hi=(jδ(2,2)jnet2hjh1i)f(net1hi)=(jδ(2,2)juji)f(net1hi)

對於t=0t=2 時:
E2 表示第2個時刻的誤差)
net0hi 表示第0個時刻隱藏層第i個神經元的淨輸入)。
和上面的思路一樣,我們容易得到:
δ(0,2)i=E2net0hi=(jδ(1,2)juji)f(net0hi)

至此,我們求完了Etnetthi 。下面我們來總結一下其通式:

Etnetthi=δ(t,t)i={(kδ(output,t)kwki)f(netthi),(jδ(t+1,t)juji)f(netthi),t=ttt


另外,對於δ(output,t)k 有以下表達式:
δ(output,t)k=Etnettyk=kEtotkotknettyk=(otkztk)



最後只要求出netthivim ,其值具體爲netthivim=xtm


最後,對於Etuim 其實和上面的差不多,主要是後面的部分不一樣,具體來說:
Etuim=tt=0Etnetthinetthiuim ,可以看到就只有等式右邊的第二項不一樣,關鍵部分是一樣的。netthiuim=ht1m

細節-1


上面提到,當只有3個時刻時,t=0,1,2。
對於誤差E2 (最後一個時刻的誤差),沒有再下一個時刻反向傳回的誤差。
那麼對於E1 (第1個時刻的誤差)存在下一個時刻反向傳回的誤差,但是在E1h1i 中的第二項jE1net2hjnet2hjh1i 仍然爲0。是因爲E1net2hj=0 ,因爲E1 的誤差和下一個時刻隱藏層的輸出沒有任何關係。

總結


看起來bptt和我們之前討論的bp本質上是一樣的,只是在一些細節的處理上由於權重共享的原因有所不同,但是基本上還是一樣的。

下面這篇文章是有一個簡單的rnn代碼,大家可以參考一下
參考文章1
代碼的bptt中每一步的迭代公式其實就是上面的公式。希望對大家有幫助~

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章