Tensorflow2.0學習筆記(七)BatchNorm層

(1)BN的作用

從上圖可以看出,Sigmoid函數在[-2,2]區間導數值在[0.1,0.25],當輸入大於2或者小於2時,導數逼近於0,從而容易出現梯度彌散的現象。通過標準化後,輸入值被映射在0附近區域,此處的導數不會太小,不會容易出現梯度彌散的現象。

如上圖所示的損失函數等高線圖可知,當x1和x2分佈相近時,收斂更加快速,優化軌跡更好。

結論:通過標準化後,輸入值被映射在0附近區域,此處的導數不會太小,不會容易出現梯度彌散的現象;網絡層輸入分佈相近,收斂速度更快。

(2)如何保證輸入的分佈相近?

其中,m爲Batch樣本數,Batch內部的均值和方差分別爲是計算出來的。

是爲了防止出現除0的錯誤而設置的較小的數,例如le-8。爲了提高BN層的表達能力,引入了縮放和平移。

參數反向傳播算法自動優化,實現網絡層按需要縮放和平移數據的分佈的目的。

(3)前向傳播

訓練過程:

計算當前Batch的,計算BN層的輸出見公式(1)

迭代更新全局訓練數據的統計值的過程見(2)

其中,momentum是需要設置的一個超參數,用於平衡更新幅度。

Momentum=0時,直接被更新爲最後一個batch的

Momentum=1時,保持不變。

在tensorflow中,Momentum的默認設置爲0.99。

測試過程:

其中,均來自訓練過程統計或優化,在測試過程中直接使用,並不會更新。

(4)反向更新

在訓練過程中,反向傳播算法根據損失L求解梯度,按照更新法則自動優化

注意:對於2D的特徵輸入X:[b,h,w,c],BN層不是計算每一個點的而是在通道C上面統計每個通道上面的所有數據的

除了C軸上面統計數據的方式,還有如下幾種:

Layer Norm:統計每個樣本的所有特徵的均值和方差

Instance Norm:統計每個樣本的每個通道上特徵的均值和方差。

Group Norm:將通道分成若干組,統計每個樣本的通道組內的特徵均值和方差。

(5)BN層

創建BN層:layer=layers.BatchNormalization()

由於BN在訓練和測試過程的行爲不同,需要通過設置training標誌來區分。

 

 

參考資料:Tensorflow 深度學習  龍龍老師

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章