文章目錄
深度神經網絡訓練過程中,各網絡層參數在不斷變化,每層網絡的輸入分佈不斷變化 ,不同的輸入分佈可能需重新訓練,此外,我們也不得不使用 較小的參數初始化、較小的學習率 訓練模型,避免網絡輸出陷入飽和區,造成BP算法的梯度消失,深度模型一般難以訓練。
作者稱這種內部網絡層輸入分佈變化的現象爲“Internal Covariate Shift”,作者提出Batch Normalization, BN算法解決Internal Covariate Shift和輸出飽和問題 ,使訓練速度加快,算法的特點是:
- 允許使用較大的學習率訓練模型,允許使用較大的參數初始化模型,經過BN處理後的輸入,基本不會處於激活函數飽和區
- 自帶正則化,與dropout類似,它爲每個隱藏層的輸入增加了噪聲,這意味着我們可使用更少的dropout
- 模型收斂速度顯著加快,ImageNet數據集分類任務使用BN,以7%的訓練步數達到相同的表現
爲什麼要使用BN?
使用mini-batch SGD訓練模型的好處?在batch使用SGD是深度網絡模型訓練的主流方法,如Adagrad等,這種方法是模型在每個step在batch上的損失更新模型參數,隨着batch size增加,訓練效果也會增加,此外,通過batch方式訓練,也可以利用到現代的計算平臺實現並行計算,如GPU的核心數爲2的n次方,將batch size設置爲2的n次方,往往可以獲得更快的訓練速度。
使用SGD訓練模型存在的問題?由於每層網絡的輸入受前面所有層網絡的參數影響,這可能導致前層網絡參數較小變化,多數神經元位於非飽和區(梯度爆炸),後層網絡變化很大,又或者前層網絡較大變化,某些神經元陷入飽和區(梯度消失),後層網絡變化很小(前面波濤洶涌,後面波瀾不驚)。
爲什麼要規範化網絡層輸入分佈?當每層網絡輸入分佈變化,意味着每層網絡要不斷地適應新的分佈,這是一種“covariate shift”現象,其實這個協變量移位的概念可以擴展至整個學習系統之外,如子網絡或者神經網絡中的一層:
式中可以是任意變換,參數通過最小化損失學習得到。
如果將看作爲子網絡的輸入,則參數的過程可看作爲(batch size=,learning rate=)
這完全等價於以作爲輸入訓練獨立的網絡,我們都知道 訓練集和測試集具有相同分佈是模型可有效訓練的前提,其實這也適用於訓練子網絡或子層,如果能保持每次訓練時,輸入的分佈相同,則在學習參數時,就不必補償因分佈變化帶來的影響。
如何解決深層網絡訓練的梯度消失問題?考慮sigmoid激活函數
隨着輸入絕對值的增大,梯度趨於0,意味着的所有維度,除非該維度絕度值較小,否則該維度流向的梯度將消失,模型訓練減緩。由於輸入受前面所有層參數的影響,改變前面層參數,很小可能將的許多維度推向飽和區,這種現象在深層網絡中尤其明顯,在實踐過程中,使用 ReLU激活函數並小心地初始化參數,可以較好的避免飽和問題。
如何使得網絡層的每次輸入分佈一致且避免輸出飽和,從而加速訓練?基於此,提出 Batch Normalization 方法,將輸入分佈規範化爲標準正太分佈,解決以上問題。
怎樣使用BN?
參照圖像處理領域以 白化(whitening, 像素值線性變化至0均值、1方差) 加速訓練的方式,Batch Normalization對神經網絡的每一層或一些層的輸入執行 “白化” 操作,固定輸入分佈,從而降低Internal Covariate Shift的影響。
白化在數據預處理過程中的目的:
- 去除特徵間相關性,使得特徵獨立;
- 使得所有特徵具有相同的均值和方差,特徵同分布;
對包含規範化的網絡層使用SGD訓練會出現什麼問題?對於包含Batch Norm的優化過程,優化器在使用SGD更新模型參數時,可能會以一種要求規範化更新(去規範化) 的形式更新參數,從而儘可能地加快訓練過程!
舉例來說,對於一個簡單的網絡層,輸入爲,偏差爲,以下列方式得到規範化後輸入:
假如期望和偏差無依賴關係,則偏差的更新公式爲
因此,
從更新前後結果可以看到,值的更新不會影響網絡輸出,當然也不會降低總體損失! 當我們在梯度下降過程之外進行參數規範化時,值可能無限制增長,導致模型blows up! 發生這一問題的原因在於,優化器未考慮到某些參數進行了規範化,也就是說損失函數中不包含某些參數的規範化項。
如何讓優化器考慮到所有參數的規範化?如果對於任意參數,網絡總是能以期望的分佈產生激活值,也就是說 損失函數對模型參數的梯度應考慮模型參數的規範化,或者說模型參數的規範化依賴於模型參數。
令向量表示網絡層輸入,是訓練樣本集,規範化輸入可表示爲
如果是其它網絡層的輸出,則規劃範項不僅依賴於,而且依賴於中的每一個訓練樣本,因爲經過多次迭代,每個訓練樣本都對產生過作用!在使用BP算法優化參數時,需分別計算關於和的雅可比矩陣,如果忽略後項,很可能造成上述 模型爆炸 的現象。計算規範化項對訓練集的雅可比矩陣計算量太大,Batch Norm算法基於整個訓練集的統計信息,對一個訓練樣本進行規範化。
如何有效地實現BN?
第一種簡單化處理:對輸入向量的每一維獨立地進行標準化(0均值、1方差規範化)
其中,期望和方差是在整個訓練集中計算得到的。
此外,對每層網絡都進行簡單的標準化,可能會降低模型的非線性表達能力,如對sigmoid的輸入標準化,目的是把一些絕對值較大的非線性區值拉回到線性區,增大導數值,因此最終大多數點被約束在[-2, 2]區間,相當於僅利用與sigmoid函數的近似線性部分:
Batch Normlization對規範化後的值在進行縮放和平移解決這個問題:
模型可以學習變量和,以恢復模型的非線性表達能力,每個神經元都有自己獨特的伸縮、平移參數。特別地,當等於標準差、等於期望時,模型等價於沒有進行規範化,這賦予了模型自我控制規範化的自由度! 換一種角度來看,引入這兩個變量之後,使得優化器僅通過修改這兩個參數 去規範化,而不是調整可能會降低網絡穩定性的網絡權重參數。
第二種簡單化處理:在每個mini-batch內部估計每次激活的均值和方差。 可看作爲規範化後的值,易證整個Norm過程可微,因此可以使用BP更新網絡參數。
BN的實現見算法1。
如何訓練和推理使用BN的網絡?
在訓練過程中,若mini-batch的樣本數大於1,我們都可以使用batch norm有效地訓練網絡,但在推理過程中,mini-batch可能僅有一個樣本,因此,使用以下公式規範化輸入
與訓練過程不同的是,上式中期望和方差都是訓練集總體的統計量,若迭代次數爲,則 無偏統計量
推理和訓練過程見算法2。
BN作用在神經元的輸入側還是輸出側?
Batch Normalization 可以應用於網絡的任一組激活 ,我們考慮element-wise非線性變化
爲非線性映射,如ReLU或者sigmiod。
我們可以將BN作用於非線性變換之前,即規範化,同樣 可以將BN作用於,但是很可能是另一非線性函數的輸出,在訓練過程中其分佈形狀可能會發生改變,而且約束其一、二階矩並不能消除協變量轉移???
相比較,更有可能具有對稱性、非稀疏分佈性,即更加具有“高斯特性”,因此規範化此部分,很可能獲得我們想要的穩定分佈。如果我們規範化,則 可以忽略偏差項,因爲它的作用會在均值相減時被抵消,因此網絡層的規範化可表示爲
爲什麼BN網絡可以使用更高的學習率?
傳統深層網絡使用較大的學習率訓練模型,可能 造成梯度消失、梯度爆炸或陷入局部極小點,而BN可阻止輸入進入激活函數的飽和區,從而避免這些問題。
從另一個角度來看,伸縮變化網絡權重參數,不影響BN輸出,對於確定的標量:
梯度爲
爲什麼BN網絡自帶正則化效果?
BM相當於對神經元的輸入進行平移和縮放,而平移和縮放的大小僅在當前mini-batch中計算得到,等價於對神經元的輸入增加了噪音,dropout也有增加噪音的作用。
The stochasticity from the batch statistics serves as a regularizer during training. (From Layer Normalization)
換一種角度來看,神經元的輸入值是經過在mini-batch中標準化後得到的,具有樣本結合性,神經網絡不再需要爲單一訓練樣本生成一個特定值,這種效應有利於模型的泛化,在 BN網絡中Dropout似乎可以減少強度或者移除。
BN在Tensorflow中的實現
主要與原論文的區別在於,每次迭代以滑動平均的方式將mini-batch的均值和方差更新到總體!
以下僅展示BN的實際過程,batch_size=8,hidden_size=3:
import tensorflow as tf
x = tf.reshape(tf.range(24, dtype=tf.float32), shape=(8, 3))
"""
<tf.Tensor: id=1269, shape=(8, 3), dtype=float32, numpy=
array([[ 0., 1., 2.],
[ 3., 4., 5.],
[ 6., 7., 8.],
[ 9., 10., 11.],
[12., 13., 14.],
[15., 16., 17.],
[18., 19., 20.],
[21., 22., 23.]], dtype=float32)>
"""
bn = tf.keras.layers.BatchNormalization()
# training
bn(x, training=True)
"""
<tf.Tensor: id=1397, shape=(8, 3), dtype=float32, numpy=
array([[-1.5275091 , -1.5275091 , -1.5275091 ],
[-1.0910779 , -1.0910779 , -1.0910779 ],
[-0.65464675, -0.65464675, -0.65464675],
[-0.21821558, -0.21821558, -0.21821558],
[ 0.21821558, 0.21821558, 0.21821558],
[ 0.65464675, 0.65464675, 0.65464675],
[ 1.0910779 , 1.0910779 , 1.0910779 ],
[ 1.5275091 , 1.5275091 , 1.5275091 ]], dtype=float32)>
"""
# inference
bn(x)
"""
<tf.Tensor: id=1433, shape=(8, 3), dtype=float32, numpy=
array([[-0.08679464, 0.73155487, 1.5499043 ],
[ 2.3930523 , 3.2114017 , 4.0297513 ],
[ 4.872899 , 5.6912484 , 6.5095983 ],
[ 7.352746 , 8.171096 , 8.989445 ],
[ 9.832593 , 10.650943 , 11.469292 ],
[12.31244 , 13.13079 , 13.949139 ],
[14.792287 , 15.610637 , 16.428986 ],
[17.272135 , 18.090483 , 18.908833 ]], dtype=float32)>
"""
Reference
1. Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
2. 詳解深度學習中的Normalization,BN/LN/WN
3.【深度學習】深入理解Batch Normalization批標準化