BN(BatchNorm)的理解

論文名字:Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
論文地址:https://arxiv.org/abs/1502.03167

一、研究意義

近年來,隨機梯度下降成了訓練深度網絡的主流方法。儘管隨機梯度下降法對於訓練深度網絡簡單高效,但是它有個毛病,就是需要我們人爲的去選擇參數,比如學習率、參數初始化、權重衰減係數、Drop out比例等。這些參數的選擇對訓練結果至關重要,以至於我們很多時間都浪費在這些的調參上,所以BN算法應運而生。

BN算法(Batch Normalization)其強大之處如下:

  1. 你可以選擇比較大的初始學習率,讓你的訓練速度飆漲。以前還需要慢慢調整學習率,甚至在網絡訓練到一半的時候,還需要想着學習率進一步調小的比例選擇多少比較合適,現在我們可以採用初始很大的學習率,然後學習率的衰減速度也很大,因爲這個算法收斂很快。當然這個算法即使你選擇了較小的學習率,也比以前的收斂速度快,因爲它具有快速訓練收斂的特性;

  2. 你再也不用去理會過擬閤中drop out、L2正則項參數的選擇問題,採用BN算法後,你可以移除這兩項了參數,或者可以選擇更小的L2正則約束參數了,因爲BN具有提高網絡泛化能力的特性;

  3. 再也不需要使用使用局部響應歸一化層了(局部響應歸一化是Alexnet網絡用到的方法,搞視覺的估計比較熟悉),因爲BN本身就是一個歸一化網絡層;

  4. 可以把訓練數據徹底打亂(防止每批訓練的時候,某一個樣本都經常被挑選到,文獻說這個可以提高1%的精度)。

在機器學習領域中有個很重要的假設:IID獨立同分布假設,就是假設訓練數據和測試數據是滿足相同分佈的,這是通過訓練數據獲得的模型能夠在測試集獲得好的效果的一個基本保障。那BatchNorm的作用是什麼呢?BatchNorm就是在深度神經網絡訓練過程中使得每一層神經網絡的輸入保持相同分佈的。

思考一個問題:爲什麼傳統的神經網絡在訓練開始之前,要對輸入的數據做Normalization?
原因在於神經網絡學習過程本質上是爲了學習數據的分佈,一旦訓練數據與測試數據的分佈不同,那麼網絡的泛化能力也大大降低;另一方面,一旦在mini-batch梯度下降訓練的時候,每批訓練數據的分佈不相同,那麼網絡就要在每次迭代的時候去學習以適應不同的分佈,這樣將會大大降低網絡的訓練速度,這也正是爲什麼我們需要對所有訓練數據做一個Normalization預處理的原因。

二、“Internal Covariate Shift”問題

什麼是“Internal Covariate Shift”?

深度網絡的訓練是一個複雜的過程,只要網絡的前面幾層發生微小的改變,那麼後面幾層就會被累積放大下去。一旦網絡某一層的輸入數據的分佈發生改變,那麼這一層網絡就需要去適應學習這個新的數據分佈,所以如果訓練過程中,訓練數據的分佈一直在發生變化,那麼將會影響網絡的訓練速度。這就是所謂的“Internal Covariate Shift”,Internal指的是深層網絡的隱層,是發生在網絡內部的事情,而不是covariate shift問題只發生在輸入層。

BN的提出

在以前的研究中發現,如果圖像處理對於輸入圖像進行白化操作,那麼神經網絡會較快收斂。白化操作就是將輸入數據分佈變換到0均值,單位方差的正態分佈。

然而白化計算量太大了,很不划算,還有就是白化不是處處可微的,所以在深度學習中,其實很少用到白化。經過白化預處理後,數據滿足條件:a、特徵之間的相關性降低,這個就相當於PCA;b、數據均值、標準差歸一化,也就是使得每一維特徵均值爲0,標準差爲1。如果數據特徵維數比較大,要進行PCA,也就是實現白化的第1個要求,是需要計算特徵向量,計算量非常大,於是爲了簡化計算,作者忽略了第1個要求,僅僅使用了下面的公式進行預處理,也就是近似白化預處理:
在這裏插入圖片描述
公式簡單粗糙,但是依舊很牛逼。因此後面我們也將用這個公式,對某一個層網絡的輸入數據做一個歸一化處理。需要注意的是,我們訓練過程中採用batch 隨機梯度下降,上面的E(xk)E(x^k)指的是每一批訓練數據神經元xkx^k的平均值;然後分母就是每一批數據神經元xkx^k激活度的一個標準差了。

三、 BN算法的基本思想

BN的基本思想其實相當直觀:因爲深層神經網絡在做非線性變換前的激活輸入值(就是那個x=WU+B,U是輸入)隨着網絡深度加深或者在訓練過程中,其分佈逐漸發生偏移或者變動,之所以訓練收斂慢,一般是整體分佈逐漸往非線性函數的取值區間的上下限兩端靠近(對於Sigmoid函數來說,意味着激活輸入值WU+B是大的負值或正值),所以這導致反向傳播時低層神經網絡的梯度消失,這是訓練深層神經網絡收斂越來越慢的本質原因,而BN就是通過一定的規範化手段,把每層神經網絡任意神經元這個輸入值的分佈強行拉回到均值爲0方差爲1的標準正態分佈,其實就是把越來越偏的分佈強制拉回比較標準的分佈,這樣使得激活輸入值落在非線性函數對輸入比較敏感的區域,這樣輸入的小變化就會導致損失函數較大的變化,意思是這樣讓梯度變大,避免梯度消失問題產生,而且梯度變大意味着學習收斂速度快,能大大加快訓練速度

其實一句話就是:對於每個隱層神經元,把逐漸向非線性函數映射後向取值區間極限飽和區靠攏的輸入分佈強制拉回到均值爲0方差爲1的比較標準的正態分佈,使得非線性變換函數的輸入值落入對輸入比較敏感的區域,以此避免梯度消失問題。因爲梯度一直都能保持比較大的狀態,所以很明顯對神經網絡的參數調整效率比較高,就是變動大,就是說向損失函數最優值邁動的步子大,也就是說收斂地快。BN說到底就是這麼個機制,方法很簡單,道理很深刻。

從上面幾個圖應該看出來BN在幹什麼了吧?其實就是把隱層神經元激活輸入x=WU+B從變化不拘一格的正態分佈通過BN操作拉回到了均值爲0,方差爲1的正態分佈,即原始正態分佈中心左移或者右移到以0爲均值,拉伸或者縮減形態形成以1爲方差的圖形。什麼意思?就是說經過BN後,目前大部分Activation的值落入非線性函數的線性區內,其對應的導數遠離導數飽和區,這樣來加速訓練收斂過程。

經過前面簡單介紹,這個時候可能我們會想當然的以爲:好像很簡單的樣子,不就是在網絡中間層數據做一個歸一化處理嘛,這麼簡單的想法,爲什麼之前沒人用呢?然而其實實現起來並不是那麼簡單的。其實如果是僅僅使用上面的歸一化公式,對網絡某一層A的輸出數據做歸一化,然後送入網絡下一層B,這樣是會影響到本層網絡A所學習到的特徵的。打個比方,比如我網絡中間某一層學習到特徵數據本身就分佈在S型激活函數的兩側,你強制把它給我歸一化處理、標準差也限制在了1,把數據變換成分佈於s函數的中間部分,這樣就相當於我這一層網絡所學習到的特徵分佈被你搞壞了,這可怎麼辦?於是文獻使出了一招:變換重構,引入了可學習參數γ、β,這就是算法關鍵之處:
在這裏插入圖片描述
每一個神經元xkx^k都會有一對這樣的參數γ、β。這樣其實當:
在這裏插入圖片描述
是可以恢復出原始的某一層所學到的特徵的。因此我們引入了這個可學習重構參數γ、β,讓我們的網絡可以學習恢復出原始網絡所要學習的特徵分佈。最後Batch Normalization網絡層的前向傳導過程公式就是:
在這裏插入圖片描述
通過上邊的公式可以看出BN的計算流程是:

  1. 計算樣本均值。
  2. 計算樣本方差。
  3. 樣本數據標準化處理。
  4. 進行平移和縮放處理。引入了γ和β兩個參數。來訓練γ和β兩個參數。引入了這個可學習重構參數γ、β,讓我們的網絡可以學習恢復出原始網絡所要學習的特徵分佈。
    在反向傳播的時候,通過鏈式求導方式,求出γ與β以及相關權值
    在這裏插入圖片描述

訓練細節:

網絡訓練中以batch_size爲最小單位不斷迭代,很顯然,新的batch_size進入網絡,由於每一次的batch有差異,實際是通過變量,以及滑動平均來記錄均值與方差。訓練完成後,推斷階段時通過γ, β,以及記錄的均值與方差計算bn層輸出。

結合論文中給出的使用過程進行解釋
在這裏插入圖片描述
輸入:待進入激活函數的變量
輸出:
1.對於K個激活函數前的輸入,所以需要K個循環。每個循環中按照上面所介紹的方法計算均值與方差。通過γ,β與輸入x的變換求出BN層輸出。
2.在反向傳播時利用γ與β求得梯度從而改變訓練權值(變量)。
3.通過不斷迭代直到訓練結束,得到γ與β,以及記錄的均值方差。
4.在預測的正向傳播時,使用訓練時最後得到的γ與β,以及均值與方差的無偏估計,通過圖中11:所表示的公式計算BN層輸出。

四、BN inference

BN在訓練的時候可以根據Mini-Batch裏的若干訓練實例進行激活數值調整,但是在推理(inference)的過程中,很明顯輸入就只有一個實例,看不到Mini-Batch其它實例,那麼這時候怎麼對輸入做BN呢?因爲很明顯一個實例是沒法算實例集合求出的均值和方差的。這可如何是好?既然沒有從Mini-Batch數據裏可以得到的統計量,那就想其它辦法來獲得這個統計量,就是均值和方差。可以用從所有訓練實例中獲得的統計量來代替Mini-Batch裏面m個訓練實例獲得的均值和方差統計量,因爲本來就打算用全局的統計量,只是因爲計算量等太大所以纔會用Mini-Batch這種簡化方式的,那麼在推理的時候直接用全局統計量即可。

決定了獲得統計量的數據範圍,那麼接下來的問題是如何獲得均值和方差的問題。很簡單,因爲每次做Mini-Batch訓練時,都會有那個Mini-Batch裏m個訓練實例獲得的均值和方差,現在要全局統計量,只要把每個Mini-Batch的均值和方差統計量記住,然後對這些均值和方差求其對應的數學期望即可得出全局統計量。

五、BN優點

  1. 不僅僅極大提升了訓練速度,收斂過程大大加快;
  2. 提高網絡的泛化能力,增加分類效果,一種解釋是這是類似於Dropout的一種防止過擬合的正則化表達方式,所以不用Dropout也能達到相當的效果;
  3. BN層本質上是一個歸一化網絡層,可以替代局部響應歸一化層(LRN層)。
  4. 另外調參過程也簡單多了,對於初始化要求沒那麼高,而且可以使用大的學習率等。
  5. 可以打亂樣本訓練順序(這樣就不可能出現同一張照片被多次選擇用來訓練)論文中提到可以提高1%的精度。

六、爲什麼BN層一般用在線性層和卷積層後面,而不是放在非線性單元后?

原文中是這樣解釋的,因爲非線性單元的輸出分佈形狀會在訓練過程中變化,歸一化無法消除他的方差偏移,相反的,全連接和卷積層的輸出一般是一個對稱,非稀疏的一個分佈,更加類似高斯分佈,對他們進行歸一化會產生更加穩定的分佈。其實想想也是的,像relu這樣的激活函數,如果你輸入的數據是一個高斯分佈,經過他變換出來的數據能是一個什麼形狀?小於0的被抑制了,也就是分佈小於0的部分直接變成0了,這樣不是很高斯了。

參考文獻:

https://blog.csdn.net/donkey_1993/article/details/81871132
https://blog.csdn.net/Fate_fjh/article/details/53375881
https://www.cnblogs.com/eilearn/p/9780696.html
https://blog.csdn.net/hjimce/article/details/50866313
https://blog.csdn.net/m0_37699976/article/details/81584101

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章