Batch Normalization原理與實戰

 

https://zhuanlan.zhihu.com/p/34879333

https://zhuanlan.zhihu.com/p/34879333

https://zhuanlan.zhihu.com/p/34879333

 

 

前言

本期專欄主要來從理論與實戰視角對深度學習中的Batch Normalization的思路進行講解、歸納和總結,並輔以代碼讓小夥伴兒們對Batch Normalization的作用有更加直觀的瞭解。

本文主要分爲兩大部分。第一部分是理論板塊,主要從背景、算法、效果等角度對Batch Normalization進行詳解;第二部分是實戰板塊,主要以MNIST數據集作爲整個代碼測試的數據,通過比較加入Batch Normalization前後網絡的性能來讓大家對Batch Normalization的作用與效果有更加直觀的感知。


(一)理論板塊

理論板塊將從以下四個方面對Batch Normalization進行詳解:

  • 提出背景
  • BN算法思想
  • 測試階段如何使用BN
  • BN的優勢

理論部分主要參考2015年Google的Sergey Ioffe與Christian Szegedy的論文內容,並輔以吳恩達Coursera課程與其它博主的資料。所有參考內容鏈接均見於文章最後參考鏈接部分。

1 提出背景

1.1 煉丹的困擾

在深度學習中,由於問題的複雜性,我們往往會使用較深層數的網絡進行訓練,相信很多煉丹的朋友都對調參的困難有所體會,尤其是對深層神經網絡的訓練調參更是困難且複雜。在這個過程中,我們需要去嘗試不同的學習率、初始化參數方法(例如Xavier初始化)等方式來幫助我們的模型加速收斂。深度神經網絡之所以如此難訓練,其中一個重要原因就是網絡中層與層之間存在高度的關聯性與耦合性。下圖是一個多層的神經網絡,層與層之間採用全連接的方式進行連接。

 

我們規定左側爲神經網絡的底層,右側爲神經網絡的上層。那麼網絡中層與層之間的關聯性會導致如下的狀況:隨着訓練的進行,網絡中的參數也隨着梯度下降在不停更新。一方面,當底層網絡中參數發生微弱變化時,由於每一層中的線性變換與非線性激活映射,這些微弱變化隨着網絡層數的加深而被放大(類似蝴蝶效應);另一方面,參數的變化導致每一層的輸入分佈會發生改變,進而上層的網絡需要不停地去適應這些分佈變化,使得我們的模型訓練變得困難。上述這一現象叫做Internal Covariate Shift。

1.2 什麼是Internal Covariate Shift

Batch Normalization的原論文作者給了Internal Covariate Shift一個較規範的定義:在深層網絡訓練的過程中,由於網絡中參數變化而引起內部結點數據分佈發生變化的這一過程被稱作Internal Covariate Shift。

這句話該怎麼理解呢?我們同樣以1.1中的圖爲例,我們定義每一層的線性變換爲 ,其中  代表層數;非線性變換爲  ,其中  爲第  層的激活函數。

隨着梯度下降的進行,每一層的參數  與  都會被更新,那麼  的分佈也就發生了改變,進而  也同樣出現分佈的改變。而  作爲第  層的輸入,意味着  層就需要去不停適應這種數據分佈的變化,這一過程就被叫做Internal Covariate Shift。

1.3 Internal Covariate Shift會帶來什麼問題?

(1)上層網絡需要不停調整來適應輸入數據分佈的變化,導致網絡學習速度的降低

我們在上面提到了梯度下降的過程會讓每一層的參數  和  發生變化,進而使得每一層的線性與非線性計算結果分佈產生變化。後層網絡就要不停地去適應這種分佈變化,這個時候就會使得整個網絡的學習速率過慢。

(2)網絡的訓練過程容易陷入梯度飽和區,減緩網絡收斂速度

當我們在神經網絡中採用飽和激活函數(saturated activation function)時,例如sigmoid,tanh激活函數,很容易使得模型訓練陷入梯度飽和區(saturated regime)。隨着模型訓練的進行,我們的參數  會逐漸更新並變大,此時  就會隨之變大,並且  還受到更底層網絡參數  的影響,隨着網絡層數的加深,  很容易陷入梯度飽和區,此時梯度會變得很小甚至接近於0,參數的更新速度就會減慢,進而就會放慢網絡的收斂速度。

對於激活函數梯度飽和問題,有兩種解決思路。第一種就是更爲非飽和性激活函數,例如線性整流函數ReLU可以在一定程度上解決訓練進入梯度飽和區的問題。另一種思路是,我們可以讓激活函數的輸入分佈保持在一個穩定狀態來儘可能避免它們陷入梯度飽和區,這也就是Normalization的思路。

1.4 我們如何減緩Internal Covariate Shift?

要緩解ICS的問題,就要明白它產生的原因。ICS產生的原因是由於參數更新帶來的網絡中每一層輸入值分佈的改變,並且隨着網絡層數的加深而變得更加嚴重,因此我們可以通過固定每一層網絡輸入值的分佈來對減緩ICS問題。

(1)白化(Whitening)

白化(Whitening)是機器學習裏面常用的一種規範化數據分佈的方法,主要是PCA白化與ZCA白化。白化是對輸入數據分佈進行變換,進而達到以下兩個目的:

  • 使得輸入特徵分佈具有相同的均值與方差。其中PCA白化保證了所有特徵分佈均值爲0,方差爲1;而ZCA白化則保證了所有特徵分佈均值爲0,方差相同;
  • 去除特徵之間的相關性。

通過白化操作,我們可以減緩ICS的問題,進而固定了每一層網絡輸入分佈,加速網絡訓練過程的收斂(LeCun et al.,1998b;Wiesler&Ney,2011)。

(2)Batch Normalization提出

既然白化可以解決這個問題,爲什麼我們還要提出別的解決辦法?當然是現有的方法具有一定的缺陷,白化主要有以下兩個問題:

  • 白化過程計算成本太高,並且在每一輪訓練中的每一層我們都需要做如此高成本計算的白化操作;
  • 白化過程由於改變了網絡每一層的分佈,因而改變了網絡層中本身數據的表達能力。底層網絡學習到的參數信息會被白化操作丟失掉。

既然有了上面兩個問題,那我們的解決思路就很簡單,一方面,我們提出的normalization方法要能夠簡化計算過程;另一方面又需要經過規範化處理後讓數據儘可能保留原始的表達能力。於是就有了簡化+改進版的白化——Batch Normalization。

2 Batch Normalization

2.1 思路

既然白化計算過程比較複雜,那我們就簡化一點,比如我們可以嘗試單獨對每個特徵進行normalizaiton就可以了,讓每個特徵都有均值爲0,方差爲1的分佈就OK。

另一個問題,既然白化操作減弱了網絡中每一層輸入數據表達能力,那我就再加個線性變換操作,讓這些數據再能夠儘可能恢復本身的表達能力就好了。

因此,基於上面兩個解決問題的思路,作者提出了Batch Normalization,下一部分來具體講解這個算法步驟。

2.2 算法

在深度學習中,由於採用full batch的訓練方式對內存要求較大,且每一輪訓練時間過長;我們一般都會採用對數據做劃分,用mini-batch對網絡進行訓練。因此,Batch Normalization也就在mini-batch的基礎上進行計算。

 

 

4 Batch Normalization的優勢   重要

Batch Normalization在實際工程中被證明了能夠緩解神經網絡難以訓練的問題,BN具有的有事可以總結爲以下三點:

(1)BN使得網絡中每層輸入數據的分佈相對穩定,加速模型學習速度

BN通過規範化與線性變換使得每一層網絡的輸入數據的均值與方差都在一定範圍內,使得後一層網絡不必不斷去適應底層網絡中輸入的變化,從而實現了網絡中層與層之間的解耦,允許每一層進行獨立學習,有利於提高整個神經網絡的學習速度。

(2)BN使得模型對網絡中的參數不那麼敏感,簡化調參過程,使得網絡學習更加穩定

在神經網絡中,我們經常會謹慎地採用一些權重初始化方法(例如Xavier)或者合適的學習率來保證網絡穩定訓練。

當學習率設置太高時,會使得參數更新步伐過大,容易出現震盪和不收斂。但是使用BN的網絡將不會受到參數數值大小的影響。

我們可以看到,經過BN操作以後,權重的縮放值會被“抹去”,因此保證了輸入數據分佈穩定在一定範圍內。另外,權重的縮放並不會影響到對  的梯度計算;並且當權重越大時,即  越大,  越小,意味着權重  的梯度反而越小,這樣BN就保證了梯度不會依賴於參數的scale,使得參數的更新處在更加穩定的狀態。

(3)BN允許網絡使用飽和性激活函數(例如sigmoid,tanh等),緩解梯度消失問題

在不使用BN層的時候,由於網絡的深度與複雜性,很容易使得底層網絡變化累積到上層網絡中,導致模型的訓練很容易進入到激活函數的梯度飽和區;通過normalize操作可以讓激活函數的輸入數據落在梯度非飽和區,緩解梯度消失的問題;另外通過自適應學習  與  又讓數據保留更多的原始信息。

(4)BN具有一定的正則化效果

在Batch Normalization中,由於我們使用mini-batch的均值與方差作爲對整體訓練樣本均值與方差的估計,儘管每一個batch中的數據都是從總體樣本中抽樣得到,但不同mini-batch的均值與方差會有所不同,這就爲網絡的學習過程中增加了隨機噪音,與Dropout通過關閉神經元給網絡訓練帶來噪音類似,在一定程度上對模型起到了正則化的效果。

另外,原作者通過也證明了網絡加入BN後,可以丟棄Dropout,模型也同樣具有很好的泛化效果。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章