深入理解批標準化(Batch Normalization)

0 前言

  • Batch Normalization作爲最近一年來DL的重要成果,已經廣泛被證明其有效性和重要性。雖然有些細節處理還解釋不清其理論原因,但是實踐證明好用纔是真的好,別忘了DL從Hinton對深層網絡做Pre-Train開始就是一個經驗領先於理論分析的偏經驗的一門學問。
  • 本文參考《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》.
  • I.I.D問題
    機器學習領域有個很重要的假設:IID獨立同分布假設,就是假設訓練數據和測試數據是滿足相同分佈的,這是通過訓練數據獲得的模型能夠在測試集獲得好的效果的一個基本保障。
  • BatchNorm的作用是什麼呢?
    BatchNorm就是在深度神經網絡訓練過程中使得每一層神經網絡的輸入保持相同分佈的。
  • 引入深度學習中常見的問題
    爲什麼深度神經網絡隨着網絡深度加深,訓練起來越困難,收斂越來越慢?很多論文都是解決這個問題的,比如ReLU激活函數,再比如Residual Network,BN本質上也是解釋並從某個不同的角度來解決這個問題的。

1 “Internal Covariate Shift”問題

1.1 什麼是“Internal Covariate Shift”

從論文名字可以看出,BN是用來解決“Internal Covariate Shift”問題的,那麼首先得理解什麼是“Internal Covariate Shift”?

論文首先說明Mini-Batch SGD相對於One Example SGD的兩個優勢:
梯度更新方向更準確;
並行計算速度快;
(爲什麼要說這些?因爲BatchNorm是基於Mini-Batch SGD的,所以先誇下Mini-Batch SGD,當然也是大實話);然後吐槽下SGD訓練的缺點:超參數調起來很麻煩。(作者隱含意思是用BN就能解決很多SGD的缺點)

  • 引入“Internal Covariate Shift”的概念(重點閱讀)
    如果ML系統實例集合<X,Y>中的輸入值X的分佈老是變,這不符合IID假設,網絡模型很難穩定的學規律,這不得引入遷移學習才能搞定嗎,我們的ML系統還得去學習怎麼迎合這種分佈變化啊。對於深度學習這種包含很多隱層的網絡結構,在訓練過程中,因爲各層參數不停在變化,所以每個隱層都會面臨covariate shift的問題,也就是在訓練過程中,隱層的輸入分佈老是變來變去,這就是所謂的“Internal Covariate Shift”,Internal指的是深層網絡的隱層,是發生在網絡內部的事情,而不是covariate shift問題只發生在輸入層。

  • 舉一個分佈不一致的例子(BY Andrew Ng)
    這裏輸入的分佈不一致(左邊和右邊)
    在這裏插入圖片描述

  • 引入Batch Norm
    能不能讓每個隱層節點的激活輸入分佈固定下來呢?這樣就避免了“Internal Covariate Shift”問題了。

  • BN的啓發來源
    啓發來源的:之前的研究表明如果在圖像處理中對輸入圖像進行白化(Whiten)操作的話——所謂白化,就是對輸入數據分佈變換到0均值,單位方差的正態分佈——那麼神經網絡會較快收斂,那麼BN作者就開始推論了:圖像是深度神經網絡的輸入層,做白化能加快收斂,那麼其實對於深度網絡來說,其中某個隱層的神經元是下一層的輸入,意思是其實深度神經網絡的每一個隱層都是輸入層,不過是相對下一層來說而已,那麼能不能對每個隱層都做白化呢?這就是啓發BN產生的原初想法,而BN也確實就是這麼做的,可以理解爲對深層神經網絡每個隱層神經元的激活值做簡化版本的白化操作。

2 Batch Norm的本質思想

2.1 本質思想

  • BN的基本思想其實相當直觀:因爲深層神經網絡在做非線性變換前的激活輸入值(就是那個x=WU+B,U是輸入)隨着網絡深度加深或者在訓練過程中,其分佈逐漸發生偏移或者變動,之所以訓練收斂慢,一般是整體分佈逐漸往非線性函數的取值區間的上下限兩端靠近(對於Sigmoid函數來說,意味着激活輸入值WU+B是大的負值或正值),所以這導致反向傳播時低層神經網絡的梯度消失,這是訓練深層神經網絡收斂越來越慢的本質原因,而BN就是通過一定的規範化手段,把每層神經網絡任意神經元這個輸入值的分佈強行拉回到均值爲0方差爲1的標準正態分佈,其實就是把越來越偏的分佈強制拉回比較標準的分佈,這樣使得激活輸入值落在非線性函數對輸入比較敏感的區域,這樣輸入的小變化就會導致損失函數較大的變化,意思是這樣讓梯度變大,避免梯度消失問題產生,而且梯度變大意味着學習收斂速度快,能大大加快訓練速度。
  • 總結一下
    對於每個隱層神經元,把逐漸向非線性函數映射後向取值區間極限飽和區靠攏的輸入分佈強制拉回到均值爲0方差爲1的比較標準的正態分佈,使得非線性變換函數的輸入值落入對輸入比較敏感的區域,以此避免梯度消失問題。因爲梯度一直都能保持比較大的狀態,所以很明顯對神經網絡的參數調整效率比較高,就是變動大,就是說向損失函數最優值邁動的步子大,也就是說收斂地快。BN說到底就是這麼個機制,方法很簡單,道理很深刻。

2.2 將激活輸入調整爲N(0,1)有何用?

在這裏插入圖片描述
這意味着在一個標準差範圍內,也就是說64%的概率x其值落在[-1,1]的範圍內,在兩個標準差範圍內,也就是說95%的概率x其值落在了[-2,2]的範圍內。那麼這又意味着什麼?我們知道,激活值x=WU+B,U是真正的輸入,x是某個神經元的激活值,假設非線性函數是sigmoid,那麼看下sigmoid(x)其圖形:
在這裏插入圖片描述
及sigmoid(x)的導數爲:G’=f(x)*(1-f(x)),因爲f(x)=sigmoid(x)在0到1之間,所以G’在0到0.25之間,其對應的圖如下:
在這裏插入圖片描述

  • 假設沒有經過BN調整前x的原先正態分佈均值是-6,方差是1,那麼意味着95%的值落在了[-8,-4]之間,那麼對應的Sigmoid(x)函數的值明顯接近於0,這是典型的梯度飽和區,在這個區域裏梯度變化很慢,爲什麼是梯度飽和區?請看下sigmoid(x)如果取值接近0或者接近於1的時候對應導數函數取值,接近於0,意味着梯度變化很小甚至消失。而假設經過BN後,均值是0,方差是1,那麼意味着95%的x值落在了[-2,2]區間內,很明顯這一段是sigmoid(x)函數接近於線性變換的區域,意味着x的小變化會導致非線性函數值較大的變化,也即是梯度變化較大,對應導數函數圖中明顯大於0的區域,就是梯度非飽和區。
  • 從上面幾個圖應該看出來BN在幹什麼了吧?其實就是把隱層神經元激活輸入x=WU+B從變化不拘一格的正態分佈通過BN操作拉回到了均值爲0,方差爲1的正態分佈,即原始正態分佈中心左移或者右移到以0爲均值,拉伸或者縮減形態形成以1爲方差的圖形。什麼意思?就是說經過BN後,目前大部分Activation的值落入非線性函數的線性區內,其對應的導數遠離導數飽和區,這樣來加速訓練收斂過程

2.3 存在的一個問題

如果都通過BN,那麼不就跟把非線性函數替換成線性函數效果相同了?這意味着什麼?我們知道,如果是多層的線性函數變換其實這個深層是沒有意義的,因爲多層線性網絡跟一層線性網絡是等價的。這意味着網絡的表達能力下降了,這也意味着深度的意義就沒有了。所以BN爲了保證非線性的獲得,對變換後的滿足均值爲0方差爲1的x又進行了scale加上shift操作(y=scale*x+shift),每個神經元增加了兩個參數scaleshift參數,這兩個參數是通過訓練學習到的,意思是通過scale和shift把這個值從標準正態分佈左移或者右移一點並長胖一點或者變瘦一點,每個實例挪動的程度不一樣,這樣等價於非線性函數的值從正中心周圍的線性區往非線性區動了動。核心思想應該是想找到一個線性和非線性的較好平衡點,既能享受非線性的較強表達能力的好處,又避免太靠非線性區兩頭使得網絡收斂速度太慢。當然,這是我的理解,論文作者並未明確這樣說。但是很明顯這裏的scale和shift操作是會有爭議的,因爲按照論文作者論文裏寫的理想狀態,就會又通過scale和shift操作把變換後的x調整回未變換的狀態,那不是饒了一圈又繞回去原始的“Internal Covariate Shift”問題裏去了嗎,感覺論文作者並未能夠清楚地解釋scale和shift操作的理論原因。

3 Batch Norm的訓練過程

3.1 再Mini-batch SGD下做BN操作

  • BN操作在哪裏進行?
    在未經過激活函數的輸入和通過激活函數的中間放一個BN層。(這裏一些論文有一些爭議,就是有的人會把BN層放在激活函數之後使用,但是一般的,都是放在中間)
  • BN操作
    在這裏插入圖片描述
    對於Mini-Batch SGD來說,一次訓練過程裏面包含m個訓練實例,其具體BN操作就是對於隱層內每個神經元的激活值來說,進行如下變換:
    在這裏插入圖片描述
    變換的意思是:某個神經元對應的原始的激活x通過減去mini-Batch內m個實例獲得的m個激活x求得的均值E(x)併除以求得的方差Var(x)來進行轉換。
  • 加入scale和shift超參數
    上文說過經過這個變換後某個神經元的激活x形成了均值爲0,方差爲1的正態分佈,目的是把值往後續要進行的非線性變換的線性區拉動,增大導數值,增強反向傳播信息流動性,加快訓練收斂速度。但是這樣會導致網絡表達能力下降,爲了防止這一點,每個神經元增加兩個調節參數(scale和shift),這兩個參數是通過訓練來學習到的,用來對變換後的激活反變換,使得網絡表達能力增強,即對變換後的激活進行如下的scale和shift操作,這其實是變換的反操作
    在這裏插入圖片描述
  • 具體算法
    在這裏插入圖片描述
  • 代碼
    tf.keras.layers.BatchNormalization()

4 Batch Norm的預測過程

4.1 預測中存在的問題

  • 在預測時,當我們的輸入只有一個實例時,應如何選取方差和均值?

4.2 解決該問題

4.2.1 方法一

利用全局統計量

  • 很簡單,因爲每次做Mini-Batch訓練時,都會有那個Mini-Batch裏m個訓練實例獲得的均值和方差,現在要全局統計量,只要把每個Mini-Batch的均值和方差統計量記住,然後對這些均值和方差求其對應的數學期望即可得出全局統計量,即:
    在這裏插入圖片描述
  • 有了均值和方差,每個隱層神經元也已經有對應訓練好的Scaling參數和Shift參數,就可以在推導的時候對每個神經元的激活數據計算BN進行變換了,:
    在這裏插入圖片描述

4.2.2 方法二(常用)

指數加權平均(在實踐中常用這個方法)

  • 訓練時,利用不同batch中對應的量進行指數加權平均計算我們的統計量均值和方差。
  • 在深度學習框架裏一般會有默認的方法,不需要調整。
    在這裏插入圖片描述

參考文獻

參考文獻1

Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章