摘要:
在前面的文章裏面,RNN訓練與BP算法,我們提到了RNN的訓練算法。但是回頭看的時候在時間的維度上沒有做處理,所以整個推導可能存在一點問題。
那麼,在這篇文章裏面,我們將介紹bptt(Back Propagation Through Time)算法如在訓練RNN。
關於bptt
這裏首先解釋一下所謂的bptt,bptt的思路其實很簡單,就是把整個RNN按時間的維度展開成一個“多層的神經網絡”。具體來說比如下圖:
既然RNN已經按時間的維度展開成一個看起來像多層的神經網絡,這個時候用普通的bp算法就可以同樣的計算,只不過這裏比較複雜的是權重共享。比如上圖中每一根線就是一個權重,而我們可以看到在RNN由於權重是共享的,所以三條紅線的權重是一樣的,這在運用鏈式法則的時候稍微比較複雜。
正文:
首先,和以往一樣,我們先做一些定義。
hti=f(netthi)
netthi=∑m(vimxtm)+∑s(uisht−1s)
nettyk=∑mwkmhtm
最後一層經過softmax的轉化
otk=enettyk∑k′enettyk′
在這裏我們使用交叉熵作爲Loss Function
Et=−∑kztklnotk
我們的任務同樣也是求∂E∂wkm 、∂E∂vim 、∂E∂uim 。
注意,這裏的E 沒有時間的下標。因爲在RNN裏,這些梯度分別爲各個時刻的梯度之和。
即:
∂E∂wkm=∑stept=0∂Et∂wkm
∂E∂vim=∑stept=0∂Et∂vim
∂E∂uim=∑stept=0∂Et∂uim 。
所以下面我們推導的是∂Et∂wkm 、∂Et∂vim 、∂Et∂uim 。
我們先推導∂Et∂wkm 。
∂Et∂wkm=∑k′∂Et∂otk′∂otk′∂nettyk∂nettyk∂wkm=(otk−ztk)∗htm 。(這一部分的推導在前面的文章已經討論過了)。
在這裏,記誤差信號:
δ(output,t)k=∂Et∂nettyk=∑k′∂Et∂otk′∂otk′∂nettyk=(otk−ztk) (後面會用到)
對於∂Et∂vim 、∂Et∂uim 其實是差不多的,所以這裏詳細介紹其中一個。這兩個導數也是RNN裏面最複雜的。
推導:∂Et∂vim
∂Et∂vim=∑tt′=0∂Et∂nett′hi∂nett′hi∂vim
對於這個式子第一次看可能有點懵逼,這裏稍微解釋一下:
從式:hti=f(∑m(vimxtm)+∑s(uisht−1s)) 中我們可以看到,vim 影響的是所有時刻的netthi,t=0,1,2,....step 。所以當Et 對vim 求偏導的時候,由於鏈式法則需要考慮到所有時刻的netthi 。
下面分成兩部分來求∂Et∂nett′hi ,∂nett′hi∂vim. 。
第一部分:∂Et∂nett′hi 。
這裏我們記δ(t′,t)i=∂Et∂nett′hi (誤差信號,和前面文章一樣)。
(由於帶着符號去求這兩個導數會讓人看起來非常懵逼,所以下面指定具體的值,後面抽象給出通式)
假設共3個時刻,即t=0,1,2。
對於t=2 ,t′=2 時:
(E2 表示第2個時刻(也是最後一個時刻)的誤差)
(net2hi 表示第2個時刻隱藏層第i個神經元的淨輸入)
具體來說:∂E2∂net2hi=∂E2∂h2i∂h2i∂net2hi
對於∂E2∂h2i=∑k′∂E2∂net2yk′∂net2yk′∂h2i
由於δ(output,t)k=∂Et∂nettyk
所以,我們有:
∂E2∂h2i=∑k′∂E2∂net2yk′∂net2yk′∂h2i=∑k′δ(output,2)k′∂net2yk′∂h2i=∑k′δ(output,2)k′wk′i
綜上:
δ(2,2)i=∂E2∂net2hi=∂E2∂h2i∂h2i∂net2hi=(∑k′δ(output,2)k′wk′i)∗f′(net2hi)
對於t=1 ,t′=2 時:
(E2 表示第2個時刻的誤差)
(net1hi 表示第1個時刻隱藏層第i個神經元的淨輸入)
具體來說:∂E2∂net1hi=∂E2∂h1i∂h1i∂net1hi
那麼∂E2∂h1i=∑k′∂E2∂net1yk′∂net1yk′∂h1i+∑j∂E2∂net2hj∂net2hj∂h1i 。請對比這個式子和上面t=2 ,t′=2 時的區別,區別在於多了一項∑j∂E2∂net2hj∂net2hj∂h1i 。這個原因我們已經在RNN與bp算法中討論過,這裏簡單的說就是由於t=1 時刻有t=2 時刻反向傳播回來的誤差,所以要考慮上這一項,但是對於t=2 已經是最後一個時刻了,沒有反向傳播回來的誤差。
對於第一項∑k′∂E2∂net1yk′∂net1yk′∂h1i 其實是0。下面簡單分析下原因:
上式進一步可以化爲:∑k′(∑k″∂E2∂o1k″∂o1k″∂net1yk′)∂net1yk′∂h1i 而E2 與第1個時刻輸出o1k″ 無關。所以爲0。
對於第二項∑j∂E2∂net2hj∂net2hj∂h1i ,我們帶入δ(t′,t)i=∂Et∂nett′hi 有:
∑j∂E2∂net2hj∂net2hj∂h1i=∑jδ(2,2)j∂net2hj∂h1i 。
同時明顯有∂net2hj∂h1i=uji
即:∂E2∂h1i=∑jδ(2,2)juji
綜上:
δ(1,2)i=∂E2∂net1hi=∂E2∂h1i∂h1i∂net1hi=(∑jδ(2,2)j∂net2hj∂h1i)∗f′(net1hi)=(∑jδ(2,2)juji)∗f′(net1hi)
對於t=0 ,t′=2 時:
(E2 表示第2個時刻的誤差)
(net0hi 表示第0個時刻隱藏層第i個神經元的淨輸入)。
和上面的思路一樣,我們容易得到:
δ(0,2)i=∂E2∂net0hi=(∑jδ(1,2)juji)∗f′(net0hi) 。
至此,我們求完了∂Et∂nett′hi 。下面我們來總結一下其通式:
∂Et∂nett′hi=δ(t′,t)i={(∑k′δ(output,t)k′wk′i)∗f′(nett′hi),(∑jδ(t′+1,t)juji)∗f′(nett′hi),t=t′t≠t′
另外,對於δ(output,t)k 有以下表達式:
δ(output,t)k=∂Et∂nettyk=∑k′∂Et∂otk′∂otk′∂nettyk=(otk−ztk)
最後只要求出∂nett′hi∂vim ,其值具體爲∂nett′hi∂vim=xtm
最後,對於∂Et∂uim 其實和上面的差不多,主要是後面的部分不一樣,具體來說:
∂Et∂uim=∑tt′=0∂Et∂nett′hi∂nett′hi∂uim ,可以看到就只有等式右邊的第二項不一樣,關鍵部分是一樣的。∂nett′hi∂uim=ht′−1m
細節-1
上面提到,當只有3個時刻時,t=0,1,2。
對於誤差E2 (最後一個時刻的誤差),沒有再下一個時刻反向傳回的誤差。
那麼對於E1 (第1個時刻的誤差)存在下一個時刻反向傳回的誤差,但是在∂E1∂h1i 中的第二項∑j∂E1∂net2hj∂net2hj∂h1i 仍然爲0。是因爲∂E1∂net2hj=0 ,因爲E1 的誤差和下一個時刻隱藏層的輸出沒有任何關係。
總結
看起來bptt和我們之前討論的bp本質上是一樣的,只是在一些細節的處理上由於權重共享的原因有所不同,但是基本上還是一樣的。
下面這篇文章是有一個簡單的rnn代碼,大家可以參考一下
參考文章1
代碼的bptt中每一步的迭代公式其實就是上面的公式。希望對大家有幫助~