Batch Normalization詳解以及pytorch實驗

Batch Normalization是google團隊在2015年論文《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》提出的。通過該方法能夠加速網絡的收斂並提升準確率。在網上雖然已經有很多相關文章,但基本都是擺上論文中的公式泛泛而談,bn真正是如何運作的很少有提及。本文主要分爲以下幾個部分:

(1)BN的原理

(2)使用pytorch驗證本文的觀點

(3)使用BN需要注意的地方(BN沒用好就是個坑)

 

1.Batch Normalization原理

我們在圖像預處理過程中通常會對圖像進行標準化處理,這樣能夠加速網絡的收斂,如下圖所示,對於Conv1來說輸入的就是滿足某一分佈的特徵矩陣,但對於Conv2而言輸入的feature map就不一定滿足某一分佈規律了(注意這裏所說滿足某一分佈規律並不是指某一個feature map的數據要滿足分佈規律,理論上是指整個訓練樣本集所對應feature map的數據要滿足分佈規律)。而我們Batch Normalization的目的就是使我們的feature map滿足均值爲0,方差爲1的分佈規律。

看到這裏應該還是蒙的,不要慌,喝口水,慢慢來。下面是從原論文中截取的原話,注意標黃的部分:

“對於一個擁有d維的輸入x,我們將對它的每一個維度進行標準化處理。”  假設我們輸入的x是RGB三通道的彩色圖像,那麼這裏的d就是輸入圖像的channels即d=3,x=(x^{(1)}, x^{(2)}, x^{(3)}),其中x^{(1)}就代表我們的R通道所對應的特徵矩陣,依此類推。標準化處理也就是分別對我們的R通道,G通道,B通道進行處理。上面的公式不用看,原文提供了更加詳細的計算公式:

我們剛剛有說讓feature map滿足某一分佈規律,理論上是指整個訓練樣本集所對應feature map的數據要滿足分佈規律,也就是說要計算出整個訓練集的feature map然後在進行標準化處理,對於一個大型的數據集明顯是不可能的,所以論文中說的是Batch Normalization,也就是我們計算一個Batch數據的feature map然後在進行標準化(batch越大越接近整個數據集的分佈,效果越好)。我們根據上圖的公式可以知道\mu _{\ss }代表着我們計算的feature map每個維度(channel)的均值,注意\mu _{\ss }是一個向量不是一個值\mu _{\ss }向量的每一個元素代表着一個維度(channel)的均值。\sigma_{\ss }^{2}代表着我們計算的feature map每個維度(channel)的標準差,注意\sigma_{\ss }^{2}是一個向量不是一個值\sigma_{\ss }^{2}向量的每一個元素代表着一個維度(channel)的方差,然後根據\mu _{\ss }\sigma_{\ss }^{2}計算標準化處理後得到的值。下圖給出了一個計算均值\mu _{\ss }和方差\sigma_{\ss }^{2}的示例:

上圖展示了一個batch size爲2(兩張圖片)的Batch Normalization的計算過程,假設feature1、feature2分別是由image1、image2經過一系列卷積池化後得到的特徵矩陣,feature的channel爲2,那麼x^{(1)}代表該batch的所有feature的channel1的數據,同理x^{^{(2)}}代表該batch的所有feature的channel2的數據。然後分別計算x^{(1)}x^{^{(2)}}的均值與方差,得到我們的\mu _{\ss }\sigma_{\ss }^{2}兩個向量。然後在根據標準差計算公式分別計算每個channel的值(公式中的\epsilon是一個很小的常量,防止分母爲零的情況)。在我們訓練網絡的過程中,我們是通過一個batch一個batch的數據進行訓練的,但是我們在預測過程中通常都是輸入一張圖片進行預測,此時batch size爲1,如果在通過上述方法計算均值和方差就沒有意義了。所以我們在訓練過程中要去不斷的計算每個batch的均值和方差,並使用移動平均(moving average)的方法記錄統計的均值和方差,在我們訓練完後我們可以近似認爲我們所統計的均值和方差就等於我們整個訓練集的均值和方差。然後在我們驗證以及預測過程中,就使用我們統計得到的均值和方差進行標準化處理。

細心的同學會發現,在原論文公式中不是還有\gamma\beta兩個參數嗎?是的,\gamma是用來調整數值分佈的方差大小,\beta是用來調節數值均值的位置。這兩個參數是在反向傳播過程中學習得到的,\gamma的默認值是1,\beta的默認值是0。

 

2.使用pytorch進行試驗

你以爲你都懂了?不一定哦。剛剛說了在我們訓練過程中,均值\mu _{\ss }和方差\sigma_{\ss }^{2}是通過計算當前批次數據得到的記爲爲\mu _{now}\sigma _{now}^{2},而我們的驗證以及預測過程中所使用的均值方差是一個統計量記爲\mu _{statistic}\sigma _{statistic}^{2}\mu _{statistic}\sigma _{statistic}^{2}的具體更新策略如下,其中momentum默認取0.1:

\large \mu _{statistic+1}=(1-momentum)*\mu _{statistic}+momentum*\mu _{now}

\large \sigma _{statistic+1}^{2}=(1-momentum)*\sigma _{statistic}^{2}+momentum*\sigma _{now}^{2}

這裏要注意一下,在pytorch中對當前批次feature進行bn處理時所使用的\large \sigma _{now}^{2}總體標準差,計算公式如下:

\bg_white \large \sigma _{now}^{2}=\frac{1}{m}\sum_{i=1}^{m}(x_{i}-\mu _{now})^{2}

在更新統計量\large \sigma _{statistic}^{2}時採用的\large \sigma _{now}^{2}樣本標準差,計算公式如下:

\bg_white \large \sigma _{now}^{2}=\frac{1}{m-1}\sum_{i=1}^{m}(x_{i}-\mu _{now})^{2}

下面是我使用pytorch做的測試,代碼如下:

(1)bn_process函數是自定義的bn處理方法驗證是否和使用官方bn處理方法結果一致。在bn_process中計算輸入batch數據的每個維度(這裏的維度是channel維度)的均值和標準差(標準差等於方差開平方),然後通過計算得到的均值和總體標準差對feature每個維度進行標準化,然後使用均值和樣本標準差更新統計均值和標準差。

(2)初始化統計均值是一個元素爲0的向量,元素個數等於channel深度;初始化統計方差是一個元素爲1的向量,元素個數等於channel深度,初始化\gamma=1,\beta=0。

import numpy as np
import torch.nn as nn
import torch


def bn_process(feature, mean, var):
    feature_shape = feature.shape
    for i in range(feature_shape[1]):
        # [batch, channel, height, width]
        feature_t = feature[:, i, :, :]
        mean_t = feature_t.mean()
        # 總體標準差
        std_t1 = feature_t.std()
        # 樣本標準差
        std_t2 = feature_t.std(ddof=1)

        # bn process
        feature[:, i, :, :] = (feature[:, i, :, :] - mean_t) / std_t1
        # update calculating mean and var
        mean[i] = mean[i]*0.9 + mean_t*0.1
        var[i] = var[i]*0.9 + (std_t2**2)*0.1
    print(feature)


# 隨機生成一個batch爲2,channel爲2,height=width=2的特徵向量 
# [batch, channel, height, width]
feature1 = torch.randn(2, 2, 2, 2)
# 初始化統計均值和方差
calculate_mean = [0.0, 0.0]
calculate_var = [1.0, 1.0]
# print(feature1.numpy())

# 注意要使用copy()深拷貝
bn_process(feature1.numpy().copy(), calculate_mean, calculate_var)

bn = nn.BatchNorm2d(2)
output = bn(feature1)
print(output)

首先我在最後設置了一個斷點進行調試,查看下官方bn對feature處理後得到的統計均值和方差。我們可以發現官方提供的bn的running_mean和running_var和我們自己計算的calculate_mean和calculate_var是一模一樣的(只是精度不同)。

然後我們打印出通過自定義bn_process函數得到的輸出以及使用官方bn處理得到輸出,明顯結果是一樣的(只是精度不同):

 

3.使用BN時需要注意的問題

(1)訓練時要將traning參數設置爲True,在驗證時將trainning參數設置爲False。在pytorch中可通過創建模型的model.train()和model.eval()方法控制。

(2)batch size儘可能設置大點,設置小後表現可能很糟糕,設置的越大求的均值和方差越接近整個訓練集的均值和方差。

(3)建議將bn層放在卷積層(Conv)和激活層(例如Relu)之間,且卷積層不要使用偏置bias,因爲沒有用,參考下圖推理,即使使用了偏置bias求出的結果也是一樣的\bg_white \large y_{i}^{b}=y_{i}

最後給出李宏毅老師關於batch normalization的視頻講解:

https://www.bilibili.com/video/av9770302?p=10 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章