【Pytorch基礎】BatchNorm常識梳理與使用

BatchNorm, 批規範化,主要用於解決協方差偏移問題,主要分三部分:

  • 計算batch均值和方差
  • 規範化
  • 仿射affine

算法內容如下:

圖源https://blog.csdn.net/LoseInVain/article/details/86476010

需要說明幾點:

  • 均值和方差是batch的統計特性,pytorch中用running_mean和running_var表示
  • $\gamma \(和\)\beta$是可學習的參數,分別是affine中的weight和bias

以BatchNorm2d爲例,分析其中變量和參數的意義:

  • affine: 仿射的開關,決定是否使用仿射這個過程。

    • affine=False則\(\gamma=1,\beta=0\) ,並且不能學習和更新。
    • affine=True則以上兩者都可以更新
  • training:模型爲訓練狀態和測試狀態時的運行邏輯是不同的。

  • track_running_stats: 決定是否跟蹤整個訓練過程中的batch的統計特性,而不僅僅是當前batch的特性。

  • num_batches_tracked:如果設置track_running_stats爲真,這個就會起作用,代表跟蹤的batch個數,即統計了多少個batch的特性。

  • momentum: 滑動平均計算running_mean和running_var

    \(\hat{x}_{\text {new }}=(1-\) momentum \() \times \hat{x}+\) momentum \(\times x_{t}\)

class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm"""
    _version = 2
    __constants__ = ['track_running_stats', 'momentum', 'eps',
                     'num_features', 'affine']

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(_NormBase, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
            self.register_parameter('num_batches_tracked', None)
        self.reset_parameters()

training和tracking_running_stats有四種組合:

  • training=True,tracking_running_stats=True: 這是正常的訓練過程,BN跟蹤的對象是整個訓練過程的batch特性。
  • training=True,tracking_running_stats=False: BN不會跟蹤整個訓練過程的batch特性,而只是計算當前batch的統計特性。
  • training=False,tracking_running_stats=True: 正常的測試過程,BN會用之前訓練好的running_mean和running_var,並且不會對其進行更新。(ps: 這就是有時候爲何有一些NAS算法會使用BN校正技術,即在訓練集上運行幾個batch,更新running_mean和running_var)
  • training=False,tracking_running_stats=False: 一般不採用這種,只計算當前測試batch統計特性,容易造成統計特性偏移,對結果造成不好的結果。

更新過程:

  • running_mean和running_var是在forward過程中更新的,記錄在buffer中(即不可通過反向傳播算法影響的部分)
  • \(\alpha, \gamma\)是在反向傳播中更新的。
  • 在蒸餾過程中,需要注意教師模型需要設置model.eval()來固定running_mean和running_var,否則會不發生變化,對結果造成不確定的影響。

參考文獻:

https://blog.csdn.net/LoseInVain/article/details/86476010

https://blog.csdn.net/yangwangnndd/article/details/94901175

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章