神經網絡15分鐘入門!——反向傳播到底是怎麼傳播的?

上一篇神經網絡15分鐘快速入門!足夠通俗易懂了吧文章中對兩層神經網絡進行了描述,從中我們知道神經網絡的過程就是正向傳播得到Loss值,再把Loss值反向傳播,並對神經網絡的參數進行更新。其中反向傳播正是神經網絡的要點所在。

本篇將對反向傳播的內容進行講解,力求通俗,畢竟只有15分鐘時間~

一、鏈式法則

在講反向傳播之前先講一下鏈式法則。

假設一個場景,一輛汽車20萬元,要收10%的購置稅,如果要買2輛,則正向傳播的過程可以畫成:

正向傳播

汽車單價20萬,最終需要支付44萬,我現在想知道汽車單價每波動1萬,對最終支付價格的影響是多少。參看下圖:我們從右向左依次求導,得到的值分別爲

①44/44=1

②44/40=1.1

③40/20=2

那麼最終價格相對於汽車單價的導數就是①×②×③=2.2

這就是鏈式法則。我們只需要知道每個節點導數值,然後求乘積就可以了。

鏈式法則的一種定義*是:

如果某個函數由複合函數表示,則該複合函數的導數可以用構成複合函數的各個函數的導數的乘積表示。

所以我們只需要關注每個節點的導數值即可。

反向傳播

2、反向傳播

下邊介紹幾種典型節點的反向傳播算法。

2.1 加法節點

如下圖:該節點可以寫作z=x+y

加法節點

很容易知道,z對x求導等於1,對y求導也等於1,所以在加法節點反向傳遞時,輸入的值會原封不動地流入下一個節點。

比如:

 

加法節點的反向傳播示例

2.2 乘法節點

如下圖,該節點可以寫作z=x*y

乘法節點

同樣很容易知道,z對x求導等於y,對y求導等於x,所以在加法節點反向傳遞時,輸入的值交叉相乘然後流入下一個節點。

比如:

乘法節點的反向傳播示例

2.3 仿射變換

所謂仿射變換就是這個式子,如果覺得眼生就去看上一篇文章

仿射變換公式

畫成圖的話就是:

仿射變換

這是神經網絡裏的一個重要形式單元。這個圖片看起來雖然複雜,但其實和乘法節點是類似的,我們對X求導,結果就是W1;對W1求導,結果就是X,到這裏和乘法節點是一樣的;對b1求導,結果爲1,原封不動地流入即可。不過需要注意的一點是,這裏的相乘是向量之間的乘法。

2.4 ReLU層

激活層我們就以ReLU爲例。回憶一下,ReLU層的形式是這樣的:

因爲當x>0時,y=x,求導爲1,也就是原封不動傳遞。

當x<=0時,y=0,求導爲0,也就是傳遞值爲0。

2.5 Softmax-with-Loss

Softmax-with-Loss指的就是Softmax和交叉熵損失的合稱。這是我們之前提到的神經網絡的最後一個環節。這部分的反向傳播推導過程比較複雜,這裏直接上結論吧(對推導過程感興趣的話可以看文末參考文獻*的附錄A):

 

該圖來自於參考文獻*

其中

從前面的層輸入的是(a1, a2, a3),softmax層輸出(y1, y2, y3)。此外,教師標籤是(t1, t2, t3),Cross Entropy Error層輸出損失L。

所謂教師標籤,就是表示是否分類正確的標籤,比如正確分類應該是第一行的結果時,(t1, t2, t3)就是(1,0,0)。

從上圖可以看出,Softmax-with-Loss的反向傳播的結果爲(y1 − t1, y2 − t2, y3 − t3)。

3、參數更新

參數的更新對象其實就是W和b,具體的在2.3中對其更新方法進行了描述,簡單來說,dW就是輸入值乘以X,db就等於輸入值。這裏用dW和db表示反向傳播到W和b節點時的計算結果。

那現在該怎樣更新W和b呢?

直接用W=W-dW;b=b-db麼?

可以,但不太好。

 

其一,需要引入正則化懲罰項。這是爲了避免最後求出的W過於集中所設置的項,比如[1/3,1/3,1/3]和[1,0,0],這兩個結果明顯前一個結果更爲分散,也是我們更想要的。爲了衡量分散度,我們用1/2W^2來表示。對該式求導,結果就是W。設正則化懲罰項的係數值爲reg,那麼修正後的dW可以寫爲:

其二,是步子邁的有點大。直接反向傳播回來的量值可能會比較大,在尋找最優解的過程中可能會直接將最優解越過去,所以在這裏設置一個參數:學習率。這個數通常很小,比如設學習率爲0.0001這樣。我們將學習率用epsilon表示,那麼最終更新後的W和b寫爲:

至此,一次反向傳播的流程就走完了。

 

總結

鏈式法則是反向傳播的基本傳遞方式,它大大簡化了反向傳播計算的複雜程度。在本例中可能還不太明顯,在有些非常複雜的網絡中,它的好處會更加顯而易見。

另外反向傳播的各個節點的算法也是比較重要的內容,本文介紹了常用的節點的反向傳播計算結果,實際應用中可能會有更多的形式。不過不用擔心,google一下,你就知道。

參數更新是反向傳播的目的,結合例子來看可能會更容易理解。下一篇文章會不使用任何框架,純手寫一個我們之前提到的神經網絡,並實現象限分類的問題。


另外,公衆號本來的流程應該是先介紹信號時頻域分析方法以及信號特徵提取方法,以及介紹一些統計和隨機過程的知識,最後纔講到神經網絡和深度學習。這篇文章算是給神經網絡系列提前開了個頭。好在神經網絡與其他方法的介紹並不衝突,就作爲另一條主線一同爲大家講解和分享吧。

歡迎持續關注!

下一篇真的很快就能來了

 

 

 

 

參考:

*《深度學習入門:基於Python的理論與實現》

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章