DL基石-神經網絡的批標準化

DL基石-神經網絡的批標準化

訓練學習系統的一個主要假設是在整個訓練過程中輸入分佈是保持不變的。對於簡單地將輸入數據映射到某些適當輸出的線性模型,這種條件總是能滿足的,但在處理由多層疊加而成的神經網絡時,情況就不一樣了。
在這樣的體系結構中,每一層的輸入都受到前面所有層參數的影響(隨着網絡變得更深,對網絡參數的小變化會被放大),因此,在一層內的反向傳播步驟中所做的一個小的變化可以產生另一層輸入的一個巨大變化,並在最後改變特徵的映射分佈。在訓練過程中,每一層都需要不斷地適應前一層得到的新分佈,這就減慢了收斂速度。
批標準化克服了這一問題,同時減少了訓練過程中內層的協方差移位(由於訓練過程中網絡參數的變化而導致的網絡激活分佈變化)
本文將討論以下內容


  • 批標準化如何減少內部協方差移位,如何改進神經網絡的訓練。
  • 如何在PyTorch中實現批標準化層。
  • 一些簡單的實驗展示了使用批標準化的優點。
    減少內部協方差移位
    減少神經網絡內部協方差移位的不良影響的一種方法是對層輸入進行歸一化,這個操作不僅使輸入具有相同的分佈,而且還使每個輸入都白化(白化是對原始數據x實現的一種變換,使變換之後數據的協方差矩陣爲單位陣),該方法是由一些相關研究提出的,這些研究表明,如果對網絡的輸入進行白化,則網絡訓練收斂得更快,因此,增強各層輸入的白化是網絡的一個理想特性。
    然而,每一層輸入的完全白化是昂貴的,並且不是完全可微的。批標準化通過考慮兩個假設克服了這個問題:


  • 我們將獨立地對每個標量特徵進行歸一化(通過設置均值爲0和方差爲1),而不是對層的輸入和輸出的特徵進行白化。
  • 我們不使用整個數據集來進行標準化,而是使用mini-batch,每個mini-batch生成每個激活層的平均值和方差的估計值。
    對於具有d維輸入的層x = (x1, x2, ..xd),我們得到了以下歸一化公式(對batch B的期望和方差進行計算):
    DL基石-神經網絡的批標準化
    然而,簡單地標準化一個層的每個輸入可能會改變層所能表示的內容。例如,對一個sigmoid的輸入進行歸一化會將其約束到非線性的線性狀態,這樣的行爲對網絡來說是不可取的,因爲它會降低其非線性的能力(它會相當於一個單層網絡)。
    DL基石-神經網絡的批標準化
    爲了解決這個問題,批標準化確保了插入到網絡中的轉換可以表示爲單位轉換(模型仍然在每個層學習到一些參數,這些參數在沒有線性映射的情況下調整從上一層接收到的激活),這是通過引入一對可學習參數gamma_k和beta_k來實現的,這兩個參數根據模型學習的內容進行縮放和移動標準化值。
    最後,得到的層的輸入(基於前一層的輸出x)爲:
    DL基石-神經網絡的批標準化
    批標準化算法
    訓練時
    全連接層
    全連接層的實現非常簡單。我們只需要得到每個批次的均值和方差,然後用之前給出的alpha和beata參數來縮放和移動。
    在反向傳播期間,我們將使用反向傳播來更新這兩個參數。












mean = torch.mean(X, axis=0)
variance = torch.mean((X-mean)**2, axis=0)
X_hat = (X-mean) * 1.0 /torch.sqrt(variance + eps)
out = gamma * X_hat + beta

卷積層
卷積層的實現幾乎與以前一樣,我們只需要執行一些改造,以適應我們從上一層獲得的輸入結構。


N, C, H, W = X.shape
mean = torch.mean(X, axis = (0, 2, 3))
variance = torch.mean((X - mean.reshape((1, C, 1, 1))) ** 2, axis=(0, 2, 3))
X_hat = (X - mean.reshape((1, C, 1, 1))) * 1.0 / torch.sqrt(variance.reshape((1, C, 1, 1)) + eps)
out = gamma.reshape((1, C, 1, 1)) * X_hat + beta.reshape((1, C, 1, 1))

在PyTorch中,反向傳播非常容易處理,這裏一件重要事情是指定alpha和beta在反向傳播階段更新它們的參數。
爲此,我們將在層中將它們聲明爲nn.Parameter(),並使用隨機值初始化它們。
推理時
在推理過程中,我們希望網絡的輸出只依賴於輸入,因此我們不能考慮之前考慮mini-batch的統計數據(它們與mini-batch大小相關,因此它們根據數據而變化)。爲了確保我們有一個固定的期望和方差,我們需要使用整個數據集來計算這些值,而不是隻考慮mini-batch,然而,就時間和計算而言,爲所有數據集計算這些統計信息是相當昂貴的。
論文中提出的方法是使用我們在訓練期間計算的滑動統計,我們使用參數beta(動量)調整當前批次計算的期望重要性:
DL基石-神經網絡的批標準化
該滑動平均值存儲在一個全局變量中,該全局變量在訓練階段更新。爲了在訓練期間將這個滑動平均值存儲在我們的層中,我們可以使用緩衝區,當我們使用PyTorch的register_buffer()方法實例化我們的層時,我們將初始化這些緩衝區。
最後一個模塊
最後一個模塊由前面描述的所有塊組成。我們在輸入數據的形狀上添加一個條件,以瞭解我們處理的是全連接層還是卷積層。
這裏需要注意的一件重要事情是,我們只需要實現forward()方法,因爲我們的類繼承自nn.Module,我們就可以自動得到backward()函數。









class CustomBatchNorm(nn.Module):

    def __init__(self, in_size, momentum=0.9, eps = 1e-5):
        super(CustomBatchNorm, self).__init__()

        self.momentum = momentum
        self.insize = in_size
        self.eps = eps

        U = uniform.Uniform(torch.tensor([0.0]), torch.tensor([1.0]))
        self.gamma = nn.Parameter(U.sample(torch.Size([self.insize])).view(self.insize))
        self.beta = nn.Parameter(torch.zeros(self.insize))

        self.register_buffer('running_mean', torch.zeros(self.insize))
        self.register_buffer('running_var', torch.ones(self.insize))

        self.running_mean.zero_()
        self.running_var.fill_(1)

    def forward(self, input):

        X = input

        if len(X.shape) not in (2, 4):
            raise ValueError("only support dense or 2dconv")

        #全連接層
        elif len(X.shape) == 2:
            if self.training:
                mean = torch.mean(X, axis=0)
                variance = torch.mean((X-mean)**2, axis=0)

                self.running_mean = (self.momentum * self.running_mean) + (1.0-self.momentum) * mean
                self.running_var = (self.momentum * self.running_var) + (1.0-self.momentum) * (input.shape[0]/(input.shape[0]-1)*variance)

            else:
                mean = self.running_mean
                variance = self.running_var

            X_hat = (X-mean) * 1.0 /torch.sqrt(variance + self.eps)
            out = self.gamma * X_hat + self.beta

                # 卷積層
        elif len(X.shape) == 4:
            if self.training:
                N, C, H, W = X.shape
                mean = torch.mean(X, axis = (0, 2, 3))
                variance = torch.mean((X - mean.reshape((1, C, 1, 1))) ** 2, axis=(0, 2, 3))

                self.running_mean = (self.momentum * self.running_mean) + (1.0-self.momentum) * mean
                self.running_var = (self.momentum * self.running_var) + (1.0-self.momentum) * (input.shape[0]/(input.shape[0]-1)*variance)
            else:
                mean = self.running_mean
                var = self.running_var

            X_hat = (X - mean.reshape((1, C, 1, 1))) * 1.0 / torch.sqrt(variance.reshape((1, C, 1, 1)) + self.eps)
            out = self.gamma.reshape((1, C, 1, 1)) * X_hat + self.beta.reshape((1, C, 1, 1))

        return out

實驗MNIST
爲了觀察批處理歸一化對訓練的影響,我們可以比較沒有批處理歸一化的簡單神經網絡和有批處理歸一化的神經網絡的收斂速度。
爲了簡單起見,我們在MNIST數據集上訓練這兩個簡單的全連接網絡,不進行預處理(只應用數據標準化)。
沒有批標準化的網絡架構



class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(28 * 28, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

有批標準化的網絡架構

class SimpleNetBN(nn.Module):
    def __init__(self):
        super(SimpleNetBN, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(28 * 28, 64),
            CustomBatchNorm(64),
            nn.ReLU(),
            nn.Linear(64, 128),
            CustomBatchNorm(128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

結果
下圖顯示了在我們的SimpleNet的第一層之後獲得的激活分佈,我們可以看到,即使經過20個epoch,分佈仍然是高斯分佈(在訓練過程中學習到的小尺度和移位)。
DL基石-神經網絡的批標準化
我們也可以看到收斂速度方面的巨大進步。綠色曲線(帶有批標準化)表明,我們可以更快地收斂到具有批標準化的最優解。
DL基石-神經網絡的批標準化
實驗結果詳見(https://github.com/sinitame/neuralnetworks-ents/blob/master/batch_normalization/batch_normaliz.ipynb)
結論
使用批標準化進行訓練的優點
一個mini-batch處理的損失梯度是對訓練集梯度的估計,訓練的質量隨着批處理大小的增加而提高。
由於gpu提供的並行性,批處理大小上的計算要比單個示例的多次計算效率高得多。
在每一層使用批處理歸一化來減少內部方差的移位,大大提高了網絡的學習效率。
原文鏈接:https://towardsdatascience.com/understanding-batch-normalization-for-neural-networks-1cd269786fa6










發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章