22Normalizaiton_layers

一、爲什麼要Normalization?

在這裏插入圖片描述
ICS問題:由於數據尺度/分佈異常,導致訓練困難

由上圖中的D(H1)=n*D(x)*D(W)=1可知,第一個隱藏層的輸出等於上一層的輸入的方差和二者之間權重的方差的連乘,所以如果數據的方差發生微小變化,那麼隨着網絡的加深,這個變化會越來越明顯,從而導致梯度消失或梯度爆炸
所以數據尺度或分佈發生變化,則會導致模型難以訓練

進行Normalization就能控制和約束數據的尺度,使得數據在一個良好的尺度和分佈範圍內,從而有助於模型的訓練

二、常見的Normalization方法

在這裏插入圖片描述

2.1 Layer Normalization( LN)

在這裏插入圖片描述
說明:
因爲BN是從特徵數的維度出發,按照batch計算均值和方差,而在變長的網絡中,如RNN,沒有辦法按照BN的計算方式來計算均值和方差
如上圖中,不同的batch中的數據對應的特徵數不同,所以沒有辦法按照batch計算均值和方差

2.1.1 nn.LayerNorm

nn.LayerNorm(normalized_shape,
			 eps=1e-05,
			 elementwise_affine=True)

主要參數:

  • normalized_shape:該層特徵形狀,
  • eps:分母修正項
  • elementwise_affine:是否需要affine transform

注意:
normalized_shape參數輸入的特徵形狀要求是C*H*W,而特徵圖的shape是B*C*H*W,所以輸入時要注意處理——feature_maps_bs.size()[1:]

# -*- coding: utf-8 -*-
import torch
import numpy as np
import torch.nn as nn
from tools.common_tools import set_seed


set_seed(1)  # 設置隨機種子

# ======================================== nn.layer norm
flag = 1
# flag = 0
if flag:
    batch_size = 8
    num_features = 6

    features_shape = (3, 4)

    feature_map = torch.ones(features_shape)  # 2D
    feature_maps = torch.stack([feature_map * (i + 1) for i in range(num_features)], dim=0)  # 3D
    feature_maps_bs = torch.stack([feature_maps for i in range(batch_size)], dim=0)  # 4D

    # feature_maps_bs shape is [8, 6, 3, 4],  B * C * H * W
    # ln = nn.LayerNorm(feature_maps_bs.size()[1:], elementwise_affine=True)
    # ln = nn.LayerNorm(feature_maps_bs.size()[1:], elementwise_affine=False)
    # ln = nn.LayerNorm([6, 3, 4])
    ln = nn.LayerNorm([6, 3])

    output = ln(feature_maps_bs)

    print("Layer Normalization")
    print(ln.weight.shape)
    print(feature_maps_bs[0, ...])
    print(output[0, ...])

在這裏插入圖片描述
注意:這裏的weight是對應Normalization公式裏的γ和β,由weight的size可以看到,LN的確是按一個batch逐元素計算的
在這裏插入圖片描述
當elementwise_affine設置爲false,ln.weight就沒有,所以說明該參數就是對應Normalization公式裏的γ和β
在這裏插入圖片描述
LN可以通過normalized_shape參數,使得在指定的shape上進行Normalization
在這裏插入圖片描述
注意:指定的shape,必須是按照BCHW,從後往前的連續形式輸入shape,如果不連續,或者不是從W開始的,就會報錯

2.2 Instance Normalization( IN)

在這裏插入圖片描述
說明:
該方法起因於圖像領域,因爲一個batch圖像數據它們有不同的風格和內容,所以不能將其混爲一談直接計算均值和方差,所以就提出了逐通道的計算均值和方差
這裏的逐通道的計算均值和方差是按照intance的,也就是按照每一個特徵圖的通道計算

2.2.1 nn.InstanceNorm

nn.InstanceNorm2d(num_features,
				  eps=1e-05,
				  momentum=0.1,
				  affine=False,
				  track_running_stats=False)

主要參數:

  • num_features:一個樣本特徵數量(最重要)
  • eps:分母修正項
  • momentum:指數加權平均估計當前mean/var
  • affine:是否需要affine transform
  • track_running_stats:是訓練狀態,還是測試狀態
# -*- coding: utf-8 -*-
import torch
import numpy as np
import torch.nn as nn
from tools.common_tools import set_seed


set_seed(1)  # 設置隨機種子
# ======================================== nn.instance norm 2d
flag = 1
# flag = 0
if flag:

    batch_size = 3
    num_features = 3
    momentum = 0.3

    features_shape = (2, 2)

    feature_map = torch.ones(features_shape)    # 2D
    feature_maps = torch.stack([feature_map * (i + 1) for i in range(num_features)], dim=0)  # 3D
    feature_maps_bs = torch.stack([feature_maps for i in range(batch_size)], dim=0)  # 4D

    print("Instance Normalization")
    print("input data:\n{} shape is {}".format(feature_maps_bs, feature_maps_bs.shape))

    instance_n = nn.InstanceNorm2d(num_features=num_features, momentum=momentum)

    for i in range(1):
        outputs = instance_n(feature_maps_bs)

        print(outputs)
        # print("\niter:{}, running_mean.shape: {}".format(i, bn.running_mean.shape))
        # print("iter:{}, running_var.shape: {}".format(i, bn.running_var.shape))
        # print("iter:{}, weight.shape: {}".format(i, bn.weight.shape))
        # print("iter:{}, bias.shape: {}".format(i, bn.bias.shape))


運行結果:
在這裏插入圖片描述
在這裏插入圖片描述
由上圖可知,output結果都是0,因爲輸入數據的每一個通道的數據都是一樣的,每個通道的數據的均值和該通道上的數是一樣的,所以求均值和方差的結果都是0,由此就可知,IN確實是以實例出發,按照通道計算均值和方差

2.3 Group Normalization( GN)

在這裏插入圖片描述
說明:
樣本數越多,估計的均值和方差就會越準,而在一些模型數據量特別大的情況下,GPU只能容納兩個甚至一個batch的數據,因此在這種情況下均值和方差的估計值不準,導致BN方法失效

2.3.1 nn.GroupNorm

nn.GroupNorm(num_groups,
			 num_channels,
			 eps=1e-05,
			 affine=True)

主要參數:

  • num_groups:分組數
  • num_channels:通道數(特徵數)
  • eps:分母修正項
  • affine:是否需要affine transform
# -*- coding: utf-8 -*-
import torch
import numpy as np
import torch.nn as nn
from tools.common_tools import set_seed


set_seed(1)  # 設置隨機種子
# ======================================== nn.grop norm
flag = 1
# flag = 0
if flag:

    batch_size = 2
    num_features = 4
    num_groups = 3   # 3 Expected number of channels in input to be divisible by num_groups

    features_shape = (2, 2)

    feature_map = torch.ones(features_shape)    # 2D
    feature_maps = torch.stack([feature_map * (i + 1) for i in range(num_features)], dim=0)  # 3D
    feature_maps_bs = torch.stack([feature_maps * (i + 1) for i in range(batch_size)], dim=0)  # 4D

    gn = nn.GroupNorm(num_groups, num_features)
    outputs = gn(feature_maps_bs)

    print("Group Normalization")
    print(gn.weight.shape)
    print(outputs[0])

在這裏插入圖片描述
注意到這裏的gn.weight.shape等於4,與num_features相同,所以說明了GN的γ和β是逐通道計算的
在這裏插入圖片描述
注意:設置的num_groups必須是要能被通道數整除的,否則會報錯,通常會設置爲2的n次冪

三、Normalization小結

在這裏插入圖片描述
BN:按照batch size的方向計算均值和方差,而且往往是在batch數較多的情況使用
LN:按照整個網絡層計算均值和方差
IN:以每個feature map出發,按照通道來計算均值和方差
GN:對feature map進行分組,按照一個group來計算均值和方差

發佈了111 篇原創文章 · 獲贊 9 · 訪問量 9126
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章