(1)BN的作用
從上圖可以看出,Sigmoid函數在[-2,2]區間導數值在[0.1,0.25],當輸入大於2或者小於2時,導數逼近於0,從而容易出現梯度彌散的現象。通過標準化後,輸入值被映射在0附近區域,此處的導數不會太小,不會容易出現梯度彌散的現象。
如上圖所示的損失函數等高線圖可知,當x1和x2分佈相近時,收斂更加快速,優化軌跡更好。
結論:通過標準化後,輸入值被映射在0附近區域,此處的導數不會太小,不會容易出現梯度彌散的現象;網絡層輸入分佈相近,收斂速度更快。
(2)如何保證輸入的分佈相近?
其中,m爲Batch樣本數,Batch內部的均值和方差分別爲是計算出來的。
是爲了防止出現除0的錯誤而設置的較小的數,例如le-8。爲了提高BN層的表達能力,引入了縮放和平移。
參數由反向傳播算法自動優化,實現網絡層按需要縮放和平移數據的分佈的目的。
(3)前向傳播
訓練過程:
計算當前Batch的,計算BN層的輸出見公式(1)
迭代更新全局訓練數據的統計值的過程見(2)
其中,momentum是需要設置的一個超參數,用於平衡更新幅度。
Momentum=0時,直接被更新爲最後一個batch的;
Momentum=1時,保持不變。
在tensorflow中,Momentum的默認設置爲0.99。
測試過程:
其中,均來自訓練過程統計或優化,在測試過程中直接使用,並不會更新。
(4)反向更新
在訓練過程中,反向傳播算法根據損失L求解梯度,按照更新法則自動優化。
注意:對於2D的特徵輸入X:[b,h,w,c],BN層不是計算每一個點的,而是在通道C上面統計每個通道上面的所有數據的。
除了C軸上面統計數據的方式,還有如下幾種:
Layer Norm:統計每個樣本的所有特徵的均值和方差
Instance Norm:統計每個樣本的每個通道上特徵的均值和方差。
Group Norm:將通道分成若干組,統計每個樣本的通道組內的特徵均值和方差。
(5)BN層
創建BN層:layer=layers.BatchNormalization()
由於BN在訓練和測試過程的行爲不同,需要通過設置training標誌來區分。
參考資料:Tensorflow 深度學習 龍龍老師