BatchNorm, 批規範化,主要用於解決協方差偏移問題,主要分三部分:
- 計算batch均值和方差
- 規範化
- 仿射affine
算法內容如下:
需要說明幾點:
- 均值和方差是batch的統計特性,pytorch中用running_mean和running_var表示
- $\gamma \(和\)\beta$是可學習的參數,分別是affine中的weight和bias
以BatchNorm2d爲例,分析其中變量和參數的意義:
-
affine: 仿射的開關,決定是否使用仿射這個過程。
- affine=False則\(\gamma=1,\beta=0\) ,並且不能學習和更新。
- affine=True則以上兩者都可以更新
-
training:模型爲訓練狀態和測試狀態時的運行邏輯是不同的。
-
track_running_stats: 決定是否跟蹤整個訓練過程中的batch的統計特性,而不僅僅是當前batch的特性。
-
num_batches_tracked:如果設置track_running_stats爲真,這個就會起作用,代表跟蹤的batch個數,即統計了多少個batch的特性。
-
momentum: 滑動平均計算running_mean和running_var
\(\hat{x}_{\text {new }}=(1-\) momentum \() \times \hat{x}+\) momentum \(\times x_{t}\)
class _NormBase(Module):
"""Common base of _InstanceNorm and _BatchNorm"""
_version = 2
__constants__ = ['track_running_stats', 'momentum', 'eps',
'num_features', 'affine']
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
track_running_stats=True):
super(_NormBase, self).__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.affine = affine
self.track_running_stats = track_running_stats
if self.affine:
self.weight = Parameter(torch.Tensor(num_features))
self.bias = Parameter(torch.Tensor(num_features))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
if self.track_running_stats:
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
else:
self.register_parameter('running_mean', None)
self.register_parameter('running_var', None)
self.register_parameter('num_batches_tracked', None)
self.reset_parameters()
training和tracking_running_stats有四種組合:
- training=True,tracking_running_stats=True: 這是正常的訓練過程,BN跟蹤的對象是整個訓練過程的batch特性。
- training=True,tracking_running_stats=False: BN不會跟蹤整個訓練過程的batch特性,而只是計算當前batch的統計特性。
- training=False,tracking_running_stats=True: 正常的測試過程,BN會用之前訓練好的running_mean和running_var,並且不會對其進行更新。(ps: 這就是有時候爲何有一些NAS算法會使用BN校正技術,即在訓練集上運行幾個batch,更新running_mean和running_var)
- training=False,tracking_running_stats=False: 一般不採用這種,只計算當前測試batch統計特性,容易造成統計特性偏移,對結果造成不好的結果。
更新過程:
- running_mean和running_var是在forward過程中更新的,記錄在buffer中(即不可通過反向傳播算法影響的部分)
- \(\alpha, \gamma\)是在反向傳播中更新的。
- 在蒸餾過程中,需要注意教師模型需要設置model.eval()來固定running_mean和running_var,否則會不發生變化,對結果造成不確定的影響。
參考文獻: