重溫循環神經網絡(RNN)

1. 寫在前面

最近用深度學習做一些時間序列預測的實驗, 用到了一些循環神經網絡的知識, 而當初學這塊的時候,只是停留在了表面,並沒有深入的學習和研究,只知道大致的原理, 並不知道具體的細節,所以導致現在復現一些經典的神經網絡會有困難, 所以這次藉着這個機會又把RNN, GRU, LSTM以及Attention的一些東西複習了一遍,真的是每一遍學習都會有新的收穫,之前學習過也沒有整理, 所以這次也藉着這個機會把這一塊的基礎內容進行一個整理和總結, 順便了解一下這些結構底層的邏輯。

當然,這次的整理是查缺補漏, 類似於知識的串聯, 一些很基礎的內容可能不會涉及到, 這一部分由於篇幅很長,所以打算用三篇基礎文章來整理, 分別是重溫循環神經網絡RNN重溫LSTM和GRU重溫Seq2Seq與Attention機制。 整理完了這些基礎知識 , 然後會總結一篇用於時間序列預測非線性自迴歸模型的論文,這篇論文用的就是帶有雙階段注意力機制的LSTM,後面也會使用keras嘗試復現並用於時間序列預測的任務,通過這樣的方式,可以把這些基礎知識從理論變成實踐, 這也是先整理三篇基礎文章的原因, 因爲復現過程中發現一些細節懵懵懂懂, 所以還是先溫習遍基礎 😉

第一篇就是重溫RNN, 這裏面會先從全連接神經網絡開始, 看一下神經網絡到底長什麼樣子以及如何進行計算, 然後針對一些特殊的任務說一下全連接神經網絡的侷限性引出循環神經網絡架構, 然後根據這個結構說一些基礎知識和運算細節, 並用numpy簡單實現一下RNN的前向傳播過程, 最後分析一下傳統RNN的侷限性, 通過反向傳播的公式看一下RNN爲什麼會存在梯度消失和爆炸, 而有了梯度消失爲啥又不能捕捉到長期關聯, 如何解決梯度消失問題等。 通過解決方法引出LSTM和變體GRU, 再去探索這兩個的原理和一些實現細節。

大綱如下

  • 理解RNN? 我們先從全連接神經網絡開始
  • 關於RNN結構的基礎知識和計算細節
  • RNN前向傳播的numpy實現與RNN的侷限性
  • 總結

Ok, let’s go!

2. RNN? 我們還是先從神經網絡開始說起吧

說到神經網絡, 我們肯定是不陌生了, 並且也非常熟悉它的運算過程, 拿我整理Pytorch的時候的一張圖再回顧一下神經網絡:
在這裏插入圖片描述
上面其實就是全連接網絡的一個總體計算過程,左上就是一個全連接神經網絡示意圖, 一個全連接神經網絡有一個輸入層, 若干個隱藏層和一個輸出層, 它的計算步驟包括前向傳播, 計算損失, 反向傳播, 更新參數,然後重複這個過程。 具體細節就不再這裏展開了, 這種網絡功能也是非常強大, 由於激活函數的存在,也善於學習很複雜的非線性關係。

但是有些任務, 比如我們的輸入是一個句子: Cat is beautiful! 讓這個神經網絡進行翻譯, 我們一般要這麼做, 首先,會把上面這3個單詞轉成向量的形式,要不然模型不認識, 可以通過one-hot或者embedding等, 然後我們喂入神經網絡, 得到輸出:
在這裏插入圖片描述
應該是一個這樣的過程, 上面這個圖得好好理解一下, 這就是如果基於全連接網絡的話會是這樣的一個圖, 這裏之所以畫成3步,就是爲了後面更好的理解循環神經網絡, 如果看過吳恩達老師的深度學習, 這裏畫的是這樣的一個圖:
在這裏插入圖片描述
這裏也拿來做個對比吧, 這個圖的話很容易把特徵和不同時間步的序列給搞混了, 並且不利於後面和遞歸神經網絡進行對比,所以我把每個單詞的翻譯給分開了, 分別通過神經網絡進行翻譯。

但是上面這種網絡存在一些問題, 很大的一個問題就是單詞和單詞之間的翻譯孤立起來了, 沒有關聯了, 但是我們知道句子翻譯很大程度上是依賴於上下文的, 如果不看上下文, 很容易把某個詞翻譯錯的。比如我前面的cat換成cats, 後面的is就需要換成are, 但是在上面的神經網絡裏面, 是學習不到這種詞與詞之間的關聯關係, 所以這種神經網絡對於這種時序性的任務不擅長, 也就是說如果我的輸入是一串序列,並且這串序列前後之間有關聯關係, 比如一個句子, 一段音樂, 一段語音,一段視頻, 一段隨時間變化的數據(股票,溫度)等這樣的數據, 如果想用一個網絡對這樣的數據進行建模, 比如捕捉這些前後的關聯關係,全連接神經網絡是不行的,什麼? 還有CNN?CNN1D確實可以處理一些簡單的時間序列數據, 但是功能比較受限, 於是循環神經網絡誕生了。

3. 關於RNN結構的基礎知識和計算細節

啥叫循環神經網絡呢? 這裏的循環到底幹什麼事情呢? 下面這個就是循環神經網絡的圖, 通過這個圖很容易看到循環吧, 但是對於初學者來說,這個圖並不是那麼好理解:
在這裏插入圖片描述
其實,雖然這個圖不是那麼好理解, 那還是這個圖能夠真正的表示循環神經網絡,更能看出一種循環, 簡單的說, 循環神經網絡在做一件這樣的事情:
在這裏插入圖片描述

我們的輸入序列不是說有時間的先後關係嗎?我們不是說要捕捉不同時間步中輸入數據的關聯嗎? 看看RNN是如何做的:

我們不妨設t-1, t, t+1三個時刻, 首先神經網絡會接收t-1時刻的輸入Xt1X_{t-1}進行運算, 然後求出隱藏狀態St1S_{t-1}和輸出Ot1O_{t-1}, 計算完畢之後, 會把隱藏狀態的值St1S_{t-1}和t時刻的輸入XtX_{t}同時作爲t時刻的神經網絡的運算輸入, 然後進行計算得到StS_{t}OtO_{t}, 計算完畢之後, 把t時刻的隱態StS_t與t+1時刻的輸入Xt+1X_{t+1}作爲t+1時刻神經網絡的輸入, 計算St+1Ot+1S_{t+1}和O_{t+1}, 這個過程是一氣呵成的, 之所以稱之爲循環,就是因爲它需要在多個時間步中反覆執行這個計算過程, 而後面時間步裏面的計算,需要用到前面時間步中的結果, 通過這種方式去捕捉序列之間的關聯關係。

下面看兩張動圖感受一下這個過程:
第一個過程, 每個時間步接收一個輸入, 並進行計算處理
在這裏插入圖片描述
第二個過程, 前一時間步處理的結果要傳遞到下一個時間步
在這裏插入圖片描述

所以上面這個過程我們可以用下面的公式表示:
Ot=g(VSt+bo)St=f(UXt+WSt1+bs)\begin{array}{l} O_{t}=g\left(V \cdot S_{t}+ b_o\right ) \\ S_{t}=f\left(U \cdot X_{t}+W \cdot S_{t-1}+b_s\right) \end{array}

也就是當前時刻t的隱藏狀態StS_{t}不僅僅取決於當前的輸入XtX_t, 還取決於前一個時刻的隱藏狀態值St1S_{t-1}, 這裏的g,fg,f激活函數了。看下圖可能會更加清楚:
在這裏插入圖片描述
如果是把我上面舉得那個例子拿下來的話,就是這樣的一個感覺
在這裏插入圖片描述
所以這裏要注意一些細節:

  • 不要以爲這是很多個全連接神經網絡,其實這就是一個神經網絡,只不過不同的時間步用了不同的輸入而已。
  • 這裏的前向傳播過程是一氣呵成的, 就是在一個時間步的循環中,直接進行每個時間步的前向傳播,得到最後的結果。
  • 注意這裏的可學習參數W,V,UW, V, U, 不同的時間步裏面都使用的這一套參數, 所以這裏的參數是共享的, 參數共享有很多好處, 比如減少計算量, 比如特徵提取, 也可以讓模型更好的泛化, 比如我去年去了北京, 和去年我去了北京, 這兩個句子意思一樣, 但是文字位置不同,共享的參數有利於學習詞義本身而不是每個位置的規則。
  • 這裏還要注意幾個名詞, 第一個就是timesteps, 表示時間步長, 也就是時間序列的長度, 需要循環迭代的次數, 第二個就是input_dim, 這個表示的每個時間步的輸入數據有多少個特徵, 第三個是units, 這個指的是上面隱藏層有多少個神經單元, 爲什麼要說這三個名詞呢? 因爲在使用實際用RNN或者LSTM的時候,這三個是核心參數,後面整理LSTM的時候,會看看keras的LSTM層如何用,那時候會再次看到這三個名詞

下面我們把上面按照時間線展開的RNN換一種形式表示,就是把那個圓圈給它再放大放大,進來看看細節:
在這裏插入圖片描述
這個就是RNN按照時間線展開的圖了,這裏的符號可能和上面表示的不一樣,這裏我就先不統一符號了,畢竟參考的資料不一樣, 如果真懂了運算原理, 就不會在乎符號的問題, 並且這裏主要也是說明計算原理,上面這個圖是取自吳恩達老師的深度學習課程, 這裏的RNN-cell, 可以理解成那個隱藏層, 裏面當然很多個隱藏單元, 我們可以看一下這裏面的整體計算:
在這裏插入圖片描述
這裏與上面不同的是,指明瞭具體的激活函數g,fg, f了,這個公式和上面循環神經網絡的計算公式一樣, 無非是符號換了一下。

下面我們可以基於上面的這個RNN的運算過程, 用numpy簡單的寫一下。爲了看清楚這個過程, 還找了張動圖:
在這裏插入圖片描述
動圖後面會給出參考鏈接。

4. RNN前向傳播的numpy實現與RNN的侷限性

根據上面的圖, 我們就用numpy代碼簡單實現一下RNN的前向傳播,這樣更容易裏面RNN的前向傳播過程, 首先,依然是先定義上面細節中的三個名詞:timesteps, input_dim和units, 這裏我們假設時間步長是4, input_dim是3, units是5, 然後10個樣本。 實現過程,我們先看看一個RNN-cell裏面的計算, 把上面的圖拿下來:
在這裏插入圖片描述
先實現一個Cell裏面的計算過程, 我們可以先看一下這裏面的輸入有當前時間步的輸入數據xt, 前一時間步的輸入數據a_prev, 然後輸出有a_t, yt_pred, 而計算公式就是上面那個,參數有WaxW_{ax}, 維度是(5, 3), 這個根據input_dim和units確定的,因爲它描述的是輸入和隱藏單元之間的一種映射, WaaW_{aa}, 維度是(5, 5), 這個是units確定, 因爲它描述的是隱藏單元與下一個時間步隱藏單元的映射, WyaW_{ya}, 維度是(2, 5),描述的是輸出與隱藏單元的映射, 所以可以直接定義一個函數, 寫這個計算過程:

def rnn_cell_forward(xt, a_prev, parameters):
	
	# 獲得參數
	Wax = parameters["Wax"]
    Waa = parameters["Waa"]
    Wya = parameters["Wya"]
    ba = parameters["ba"]
    by = parameters["by"]
	
	
	# cell 的前向傳播
	a_t = np.tanh(np.dot(Wax, xt) + np.dot(Waa, a_prev) + ba)
	yt_pred = softmax(np.dot(Wya, a_t) + by)
	
	# 保存一下重要結果
	cache = (a_t, a_prev, xt, parameters)
	
	return a_t, yt_pred, cache

xt = np.random.randn(3,10)
a_prev = np.random.randn(5,10)

# 初始化參數
Waa = np.random.randn(5, 5)
Wax = np.random.randn(5, 3)
Wya = np.random.randn(2, 5)
ba = np.random.randn(5, 1)
by = np.random.randn(2, 1)

a_next, yt_pred, cache = rnn_cell_forward(xt, a_prev, parameters)

這就是cell的前向傳播, 當然這裏面有一些細節, 比如像那些參數, a_next, a_prev, xt這些東西, 最好都保存一下,反向傳播的時候會用到。

一個cell的前向傳播完畢, 那麼整個RNN的前向傳播應該咋寫呢? 還是看圖
在這裏插入圖片描述
有了一個cell的計算, 整個RNN其實就是時間步的一個循環, 所以可以用一個時間步循環解決這個問題, 還是先分析一下, 接收的數據是a0和整個x, 這個x的維度就是(input_dim, m, timesteps), 而a0的維度就是(units, m), 而這裏的輸出有最後一步的a, 這個維度是(units, m, timesteps), y_pred, 維度是(n_y, m, timesteps)。 過程就是遍歷每個時間步, 得到本時間步的輸出y和下一步的輸入a_next, 把這個加入到最後的y和a裏面即可。

def rnn_forward(x, a0, parameters):
	
	caches = []   # 保存結果
	
	# 獲取到那幾個重要的參數
	input_dim, m, T_x = x.shape
	n_y, units = parameters['Wya'].shape
	
	# 初始化a, y_pred
	a = np.zeros((units, m, T_x))
	y_pred = np.zeros((n_y, m, T_x))
	
	a_next = a0

	for t in range(T_x):
		a_next, yt_pred, cache = rnn_cell_forward(x[:,:,t], a_next, parameters)
		a[:, :, t] = a_next
		y_pred[:, :, t] = yt_pred

		caches.append(cache)
	
	caches = (caches, x)
	return a, y_pred, caches

這就是RNN的前向傳播過程, 這樣理解這個循環神經網絡的計算過程爲啥是一氣呵成了吧, 但是這裏還要注意一下, 這個和普通的全連接前向傳播的循環可不一樣, 這裏是只有一層隱藏層, 然後這裏的循環是時間步的循環, 而全連接網絡那裏的循環是多個隱藏層, 循環是隱藏層的循環計算, 如果不理解的話很容易就搞亂了。這裏是一層的RNN, 但是有一個時間步的循環計算, 而普通的一層全連接網絡,是不用循環計算的

那麼這裏又要看一個問題了, 我們知道全連接那分析的時候,如果層數很多, 就會出現梯度消失或者爆炸, 這是因爲在反向傳播的時候, 通過鏈式法則的推導,會用到上一層正向傳播過程中的輸出, 而這個輸出,又依賴於前面層數的輸出,這是一個連乘的計算過程, 所以如果前面某一層某個值很大或者很小的時候,就會導致後面某些層的輸出很小, 這樣就會導致梯度消失或者爆炸, 如果不知道我在說啥的,建議補一下基礎, 或者看看系統學習Pytorch筆記六:模型的權值初始化與損失函數介紹, 這裏面解釋了一點梯度消失和爆炸現象。

而回到RNN, 其實也存在這個現象,爲啥呢? 因爲上面說了, 一層的RNN就會有一個時間步的循環計算, 而這個時間步的長度是依賴於輸入序列的長度的, 如果序列很長很長,那麼這裏也相當於前向傳播有了一個很深的連乘運算, 則RNN的反向傳播過程會隨着時間序列產生長期依賴,這是因爲每一步的隱態StS_t隨着時間序列在前向傳播, 而StS_t又是Wx,WsW_x, W_s的函數, 所以會有一個時間步之間隱態的一個連乘計算, 有連乘,就會出現危險, 如果不明白的話,看個計算過程就明白了, 順便看一下RNN的反向傳播:
在這裏插入圖片描述
就拿一個三個時間步的RNN來看, 通過上面的分析,我們可以寫一下它的前向傳播過程:
S1=WxX1+WsS0+b1O1=WoS1+b2S2=WxX2+WsS1+b1O2=WoS2+b2S3=WxX3+WsS2+b1O3=WoS3+b2\begin{array}{l} S_{1}=W_{x} X_{1}+W_{s} S_{0}+b_{1} \qquad O_{1}=W_{o} S_{1}+b_{2} \\ S_{2}=W_{x} X_{2}+W_{s} S_{1}+b_{1}\qquad O_{2}=W_{o} S_{2}+b_{2} \\ S_{3}=W_{x} X_{3}+W_{s} S_{2}+b_{1} \qquad O_{3}=W_{o} S_{3}+b_{2} \end{array}

假設在t=3時刻, 損失函數爲L3=12(Y3O3)2L_{3}=\frac{1}{2}\left(Y_{3}-O_{3}\right)^{2}

則對於一次訓練任務的損失函數爲L=t=0TLtL=\sum_{t=0}^{T} L_{t}, 就是每個時間步損失函數的一個累加。 那麼我們開始考慮反向傳播的過程, 其實就是對Wx,Ws,Wo,b1,b2W_x, W_s, W_o, b_1, b_2求偏導, 並不斷調整它們使L儘可能達到最小。

那麼我們就對t3時刻的Wx,Ws,WoW_x, W_s, W_o求一下偏導:
L3Wo=L3O3O3Wo=(O3Y3)S3\frac{\partial L_{3}}{\partial W_{o}}=\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial W_{o}}=(O_3-Y_3) S_3
這個我們發現,對於WoW_o求導, 並沒有產生長期依賴。而下面看看對於Wx,WsW_x, W_s求偏導:
L3Wx=L3O3O3S3S3Wx+L3O3O3S3S3S2S2Wx+L3O3O3S3S3S2S2S1S1WxL3Ws=L3O3O3S3S3Ws+L3O3O3S3S3S2S2Ws+L3O3O3S3S3S2S2S1S1Ws\begin{aligned} \frac{\partial L_{3}}{\partial W_{x}}=\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial S_{3}} \frac{\partial S_{3}}{\partial W_{x}}+\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial S_{3}} \frac{\partial S_{3}}{\partial S_{2}} \frac{\partial S_{2}}{\partial W_{x}}+\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial S_{3}} \frac{\partial S_{3}}{\partial S_{2}} \frac{\partial S_{2}}{\partial S_{1}} \frac{\partial S_{1}}{\partial W_{x}} \\ \\ \frac{\partial L_{3}}{\partial W_{s}}=\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial S_{3}} \frac{\partial S_{3}}{\partial W_{s}}+\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial S_{3}} \frac{\partial S_{3}}{\partial S_{2}} \frac{\partial S_{2}}{\partial W_{s}}+\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial S_{3}} \frac{\partial S_{3}}{\partial S_{2}} \frac{\partial S_{2}}{\partial S_{1}} \frac{\partial S_{1}}{\partial W_{s}} \end{aligned}
這兩個就會產生一種時間序列依賴,因爲StS_t隨着時間序列在前向傳播, 而StS_t又是Wx,WsW_x, W_s的函數。 根據上面求偏導的過程, 可以得到任意一個時刻對Wx,WsW_x, W_s求偏導的公式:
LtWx=k=0tLtOtOtSt(j=k1tSjSj1)SkWx\frac{\partial L_{t}}{\partial W_{x}}=\sum_{k=0}^{t} \frac{\partial L_{t}}{\partial O_{t}} \frac{\partial O_{t}}{\partial S_{t}}\left(\prod_{j=k-1}^{t} \frac{\partial S_{j}}{\partial S_{j-1}}\right) \frac{\partial S_{k}}{\partial W_{x}}

WsW_s求偏導也是同理。 如果加上激活函數, Sj=tanh(WxXj+WsSj1+b1)S_{j}=\tanh \left(W_{x} X_{j}+W_{s} S_{j-1}+b_{1}\right), 則
j=k+1tSjSj1=j=k+1ttanhWs\prod_{j=k+1}^{t} \frac{\partial S_{j}}{\partial S_{j-1}}=\prod_{j=k+1}^{t} \tanh ^{\prime} W_{s}

這樣就清晰了, tanh的導數我們知道是小於等於1的, 而這個連乘的大小其實取決於這個WsW_s, 如果WsW_s很大, 那麼在求導過程中就會梯度爆炸, 如果很小,那麼就會出現梯度消失, 所以這就是RNN中梯度消失或者爆炸的原因, 關鍵之處就是這個連乘運算。 注意一下, 之類說的梯度消失,並不是說後面時刻參數更新時梯度爲0, 而是說後面時刻參數更新的時候, 越往前的序列信息對更新起不到作用了。

並且我們假設有個t=20的時候看看Wx,WsW_x, W_s求偏導的公式:
L20Wx=L20O20O20S20S20Wx+L20O20O20S20S20S19S19Wx+L20O20O20S20S20S19S19S18S18Wx+....+0+0+...+0\begin{aligned} \frac{\partial L_{20}}{\partial W_{x}}=\frac{\partial L_{20}}{\partial O_{20}} \frac{\partial O_{20}}{\partial S_{20}} \frac{\partial S_{20}}{\partial W_{x}}+\frac{\partial L_{20}}{\partial O_{20}} \frac{\partial O_{20}}{\partial S_{20}} \frac{\partial S_{20}}{\partial S_{19}} \frac{\partial S_{19}}{\partial W_{x}}+\frac{\partial L_{20}}{\partial O_{20}} \frac{\partial O_{20}}{\partial S_{20}} \frac{\partial S_{20}}{\partial S_{19}} \frac{\partial S_{19}}{\partial S_{18}} \frac{\partial S_{18}}{\partial W_{x}} + ....+ 0 + 0 + ...+ 0\end{aligned}
SkWx=Xk\frac{\partial S_k}{\partial W_x}=X_k, 這個其實也就是再說,如果某時刻距離當前時刻越遠, 比如t=3, 也就是上面加法的後面一長串累乘到出現S3S_3的時候,因爲有了這一長串累乘,很容易導致梯度消失,那麼S3Wx=X3\frac{\partial S_3}{\partial W_x}=X_3不起作用了(因爲累乘那塊是0, 乘以這個梯度也是0), 這也就是說在t=20的時候,t=3時刻的輸入對於t=20時參數更新L20Wx\frac{L_{20}}{W_x}是起不到任何作用的。 這就相當於RNN並沒有辦法捕捉這種長期的依賴關係, 只能捕捉局部的依賴關係, 比如t=20時參數的更新,可能只依賴於X19,X18,X17X_{19},X_{18}, X_{17}這3步的輸入值。

這對應着吳恩達老師講的那個例子: The cat, which ate already, …, was full。 就是後面的was還是were, 要看前面是cat, 還是cats, 但是一旦中間的這個which 句子很長, cat的信息根本傳不到was這裏來。對was的更新沒有任何幫助, 這是RNN一個很大的不足之處。

所以,通過上面的分析, 我們知道了RNN存在着一個很大的問題梯度消失,而RNN出現梯度消失問題之後, 就沒法再捕捉序列之間的長期關聯或者依賴關係

而解決上面這個問題的根本,其實就是j=k+1tSjSj1\prod_{j=k+1}^{t} \frac{\partial S_{j}}{\partial S_{j-1}}, 因爲這個連乘, 纔會有梯度消失或者爆炸現象,進而纔會無法捕捉長期依賴。 那麼如何解決這個問題呢? 那就是讓這個連乘保持一個常量, 這樣的話就不會梯度消失或者爆炸了。 當然RNN是做不到了, 所以LSTM就誕生了。

5. 總結

好了, 這篇基礎知識的內容就整理到這裏, 如果後面加上LSTM就會太多了, 所以趁熱乎快速回顧一下: 這篇文章就是圍繞着時序序列的任務進行展開, 從全連接網絡開始,複習了一下DNN的步驟和處理這種時序序列任務的侷限性, 引出了RNN, 然後重點說了一下RNN的運算原理和幾個細節部分, 糾正一下初學者對RNN的理解誤差, 然後爲了更加詳細的理解RNN的計算原理,用numpy實現了一下前向傳播的過程, 並有一個例子寫了一下反向傳播的公式, 並解釋了一下爲什麼RNN會存在梯度消失和爆炸現象, 爲什麼不能捕捉長期依賴關係, 最後又分析了這兩個問題的解決關鍵在什麼地方。

而RNN的這兩個問題到底是如何解決的呢? 下一篇重溫LSTM及其變體GRU中告訴你 😉

參考

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章