Batch Normalization

1. Batch Normalization Concepts

1.1 What is Batch Normalization

Batch Normalization: normalization over a batch

  • Batch: a batch of data, usually a mini-batch
  • Normalization: zero mean, unit variance

Advantages:

  1. Allows a larger learning rate, which speeds up convergence
  2. Reduces the need for carefully designed weight initialization
  3. Allows dropping dropout, or using a smaller dropout rate
  4. Allows dropping L2 regularization, or using a smaller weight decay
  5. Removes the need for LRN (local response normalization)

Reference: "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" (Ioffe & Szegedy, 2015)

Computation:
[Figure: the Batch Normalizing Transform (Algorithm 1) from the paper]
Input: a mini-batch of data (m samples) and two learnable parameters γ and β

Steps (see the equations below):

  • Compute the mean and variance of the mini-batch
  • Normalize each sample in the mini-batch; ε is a small correction term that keeps the denominator from being zero
  • Apply an affine transform to the normalized data, which can be understood as scaling and shifting, to restore representational capacity
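
Written out, the steps above correspond to the following equations from the paper:

$$\mu_{\mathcal{B}} = \frac{1}{m}\sum_{i=1}^{m} x_i \qquad \sigma_{\mathcal{B}}^{2} = \frac{1}{m}\sum_{i=1}^{m}\left(x_i-\mu_{\mathcal{B}}\right)^{2}$$

$$\hat{x}_i = \frac{x_i-\mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^{2}+\epsilon}} \qquad y_i = \gamma\,\hat{x}_i+\beta$$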

1.2 Internal Covariate Shift (ICS)

[Figure: variance of a linear layer's output, D(H1) = n · D(X) · D(W)]
ICS can be loosely understood as a change in the scale or distribution of the data as it flows through the network

From D(H1) = n · D(X) · D(W) = 1 in the figure above, the variance of the first hidden layer's output is the product of the number of input neurons n, the variance of the layer's input, and the variance of the weights between them. A small change in the variance of the data is therefore amplified layer by layer as the network deepens, which can lead to vanishing or exploding gradients.
So when the scale or distribution of the data shifts, the model becomes hard to train.
Batch Normalization was introduced precisely to address this problem; the quick numeric check below illustrates the variance relation.
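
A minimal sketch of the relation (illustrative numbers, not part of the lesson code):

# illustrate D(H1) = n * D(X) * D(W) for a single linear layer
import torch

torch.manual_seed(0)
n = 256
x = torch.randn(1024, n)    # D(X) = 1
w = torch.randn(n, n)       # D(W) = 1
h1 = x @ w                  # D(H1) ≈ n * D(X) * D(W) = 256
print(h1.std().item())      # roughly sqrt(256) = 16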

1.3 Applying Batch Normalization

1.3.1 With BN, explicit weight initialization is not required

# -*- coding: utf-8 -*-

import torch
import numpy as np
import torch.nn as nn
from tools.common_tools import set_seed

set_seed(1)  # set the random seed


class MLP(nn.Module):
    def __init__(self, neural_num, layers=100):
        super(MLP, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(neural_num, neural_num, bias=False) for i in range(layers)])
        self.bns = nn.ModuleList([nn.BatchNorm1d(neural_num) for i in range(layers)])
        self.neural_num = neural_num

    def forward(self, x):

        for (i, linear), bn in zip(enumerate(self.linears), self.bns):
            x = linear(x)
            x = bn(x)          # apply BN before the activation function
            x = torch.relu(x)

            if torch.isnan(x.std()):
                print("output is nan in {} layers".format(i))
                break

            print("layers:{}, std:{}".format(i, x.std().item()))

        return x

    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):

                # method 1
                # nn.init.normal_(m.weight.data, std=1)    # normal: mean=0, std=1

                # method 2 kaiming
                nn.init.kaiming_normal_(m.weight.data)


neural_nums = 256
layer_nums = 100
batch_size = 16

net = MLP(neural_nums, layer_nums)
# net.initialize()

inputs = torch.randn((batch_size, neural_nums))  # normal: mean=0, std=1

output = net(inputs)
print(output)

[Figure: per-layer standard deviation of the MLP with BN and no weight initialization]
As the output shows, with BN and without any weight initialization, the standard deviation of every layer still stays in a healthy range

1.3.2 Applying BN in a binary classification model

# -*- coding:utf-8 -*-
"""
@file name  : bn_application.py
# @author   : TingsongYu https://github.com/TingsongYu
@date       : 2019-11-01
@brief      : nn.BatchNorm使用
"""
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
from tools.common_tools import set_seed


class LeNet_bn(nn.Module):
    def __init__(self, classes):
        super(LeNet_bn, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.bn1 = nn.BatchNorm2d(num_features=6)

        self.conv2 = nn.Conv2d(6, 16, 5)
        self.bn2 = nn.BatchNorm2d(num_features=16)

        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.bn3 = nn.BatchNorm1d(num_features=120)

        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, classes)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)

        out = F.max_pool2d(out, 2)

        out = self.conv2(out)
        out = self.bn2(out)
        out = F.relu(out)

        out = F.max_pool2d(out, 2)

        out = out.view(out.size(0), -1)

        out = self.fc1(out)
        out = self.bn3(out)
        out = F.relu(out)

        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight.data, 0, 1)
                m.bias.data.zero_()


set_seed(1)  # set the random seed
rmb_label = {"1": 0, "100": 1}

# hyper-parameters
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1

# ============================ step 1/5 data ============================

split_dir = os.path.join("..", "..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomGrayscale(p=0.8),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# build RMBDataset instances
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# build DataLoader
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 model ============================

# net = LeNet_bn(classes=2)
net = LeNet(classes=2)
# net.initialize_weights()

# ============================ step 3/5 loss function ============================
criterion = nn.CrossEntropyLoss()                                                   # choose the loss function

# ============================ step 4/5 optimizer ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)                        # choose the optimizer
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)     # set the learning-rate decay schedule

# ============================ step 5/5 training ============================
train_curve = list()
valid_curve = list()

iter_count = 0
# build the SummaryWriter
writer = SummaryWriter(comment='test_your_comment', filename_suffix="_test_your_filename_suffix")

for epoch in range(MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):

        iter_count += 1

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # gather classification statistics
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()

        # print training info
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

        # log data to the event file
        writer.add_scalars("Loss", {"Train": loss.item()}, iter_count)
        writer.add_scalars("Accuracy", {"Train": correct / total}, iter_count)

    scheduler.step()  # update the learning rate

    # validate the model
    if (epoch+1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().sum().numpy()

                loss_val += loss.item()

            valid_curve.append(loss.item())
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val, correct / total))

            # log data to the event file
            writer.add_scalars("Loss", {"Valid": loss.item()}, iter_count)
            writer.add_scalars("Accuracy", {"Valid": correct_val / total_val}, iter_count)

train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval  # the valid curve has one record per epoch; convert its points to iteration indices
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()


[Figure: training and validation loss of LeNet without weight initialization]
As the curve shows, when the weights are not explicitly initialized, the loss spikes sharply around iterations 80 to 100
[Figure: training and validation loss of LeNet with weight initialization]
With weight initialization, the spike appears around iterations 20 to 40, after which the loss keeps decreasing and reaches a good final value

[Figure: training and validation loss of LeNet_bn]
With BN layers added, the loss still shows spikes, but they are much smaller and the loss stays within a healthy range throughout training

2. PyTorch's BatchNorm1d/2d/3d Implementation

2.1 The base class: _BatchNorm

__init__(self,
         num_features,
         eps=1e-5,
         momentum=0.1,
         affine=True,
         track_running_stats=True)

_BatchNorm is the common base class of:

  • nn.BatchNorm1d
  • nn.BatchNorm2d
  • nn.BatchNorm3d

Parameters:

  • num_features: number of features per sample (the most important one)
  • eps: small term added to the denominator for numerical stability
  • momentum: factor of the exponential weighted average used to estimate the running mean/var
  • affine: whether to apply the learnable affine transform (γ, β)
  • track_running_stats: whether to track running statistics during training for use at evaluation time (see the sketch below)
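
A minimal sketch of how these arguments show up on the module (the values are illustrative, not from the lesson code):

import torch.nn as nn

bn = nn.BatchNorm2d(num_features=64, eps=1e-5, momentum=0.1,
                    affine=True, track_running_stats=True)
print(bn.weight.shape, bn.bias.shape)   # torch.Size([64]) torch.Size([64]) -> gamma and beta
print(bn.running_mean.shape)            # torch.Size([64])

bn_plain = nn.BatchNorm2d(64, affine=False, track_running_stats=False)
print(bn_plain.weight, bn_plain.running_mean)   # None None -- no affine parameters, no running stats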

2.2 nn.BatchNorm1d/2d/3d

  • nn.BatchNorm1d — expects input of shape (N, C) or (N, C, L)
  • nn.BatchNorm2d — expects input of shape (N, C, H, W)
  • nn.BatchNorm3d — expects input of shape (N, C, D, H, W)

Main attributes:

  • running_mean: running estimate of the mean
  • running_var: running estimate of the variance
  • weight: gamma of the affine transform
  • bias: beta of the affine transform

During training: the running mean and variance are updated with an exponential weighted average
running_mean = (1 - momentum) * pre_running_mean + momentum * mean_t
running_var = (1 - momentum) * pre_running_var + momentum * var_t

At evaluation time: the running statistics accumulated during training are used instead of the batch statistics (see the sketch below)
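
A minimal sketch of this train/eval difference (illustrative numbers, not from the lesson code):

import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)
x = torch.randn(8, 3) * 5 + 10            # a batch with mean ~10 and std ~5

bn.train()
y_train = bn(x)                           # normalized with the current batch statistics
print(y_train.mean().item())              # roughly 0
print(bn.running_mean)                    # moved from 0 toward 10 by the EMA update

bn.eval()
y_eval = bn(x)                            # normalized with running_mean / running_var
print(y_eval.mean().item())               # far from 0: the running stats still lag the batch stats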

Code example:

# -*- coding: utf-8 -*-

import torch
import numpy as np
import torch.nn as nn
from tools.common_tools import set_seed



set_seed(1)  # set the random seed

# ======================================== nn.BatchNorm1d
flag = 1
# flag = 0
if flag:

    batch_size = 3
    num_features = 5
    momentum = 0.3

    features_shape = (1)

    feature_map = torch.ones(features_shape)                                                    # 1D
    feature_maps = torch.stack([feature_map*(i+1) for i in range(num_features)], dim=0)         # 2D
    feature_maps_bs = torch.stack([feature_maps for i in range(batch_size)], dim=0)             # 3D
    print("input data:\n{} shape is {}".format(feature_maps_bs, feature_maps_bs.shape))

    bn = nn.BatchNorm1d(num_features=num_features, momentum=momentum)

    running_mean, running_var = 0, 1

    for i in range(2):
        outputs = bn(feature_maps_bs)

        print("\niteration:{}, running mean: {} ".format(i, bn.running_mean))
        print("iteration:{}, running var:{} ".format(i, bn.running_var))

        mean_t, var_t = 2, 0

        running_mean = (1 - momentum) * running_mean + momentum * mean_t
        running_var = (1 - momentum) * running_var + momentum * var_t

        print("iteration:{}, 第二個特徵的running mean: {} ".format(i, running_mean))
        print("iteration:{}, 第二個特徵的running var:{}".format(i, running_var))


[Figure: console output of the nn.BatchNorm1d example]
Note that the running mean and variance are not computed from the current mini-batch alone: the exponential weighted average blends the current batch statistics with those of previous mini-batches, and these running statistics are what BN uses at evaluation time

# -*- coding: utf-8 -*-

import torch
import numpy as np
import torch.nn as nn
from tools.common_tools import set_seed



set_seed(1)  # set the random seed
# ======================================== nn.BatchNorm2d
flag = 1
# flag = 0
if flag:

    batch_size = 3
    num_features = 6
    momentum = 0.3
    
    features_shape = (2, 2)

    feature_map = torch.ones(features_shape)                                                    # 2D
    feature_maps = torch.stack([feature_map*(i+1) for i in range(num_features)], dim=0)         # 3D
    feature_maps_bs = torch.stack([feature_maps for i in range(batch_size)], dim=0)             # 4D

    print("input data:\n{} shape is {}".format(feature_maps_bs, feature_maps_bs.shape))

    bn = nn.BatchNorm2d(num_features=num_features, momentum=momentum)

    running_mean, running_var = 0, 1

    for i in range(2):
        outputs = bn(feature_maps_bs)

        print("\niter:{}, running_mean.shape: {}".format(i, bn.running_mean.shape))
        print("iter:{}, running_var.shape: {}".format(i, bn.running_var.shape))

        print("iter:{}, weight.shape: {}".format(i, bn.weight.shape))
        print("iter:{}, bias.shape: {}".format(i, bn.bias.shape))



[Figure: console output of the nn.BatchNorm2d example]
As the output shows, BN statistics are computed per feature (channel): running_mean, running_var, weight, and bias each have num_features elements. A small manual check of the per-channel normalization follows.
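
A minimal check that BatchNorm2d normalizes each channel with its own statistics (illustrative, not from the lesson code):

import torch
import torch.nn as nn

x = torch.randn(4, 6, 2, 2)
bn = nn.BatchNorm2d(num_features=6, affine=False)
y = bn(x)                                              # training mode: batch statistics are used

mean_c = x.mean(dim=(0, 2, 3), keepdim=True)           # one mean per channel
var_c = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
y_manual = (x - mean_c) / torch.sqrt(var_c + bn.eps)

print(torch.allclose(y, y_manual, atol=1e-5))          # True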

# -*- coding: utf-8 -*-

import torch
import numpy as np
import torch.nn as nn
from tools.common_tools import set_seed



set_seed(1)  # set the random seed


# ======================================== nn.BatchNorm3d
flag = 1
# flag = 0
if flag:

    batch_size = 3
    num_features = 4
    momentum = 0.3

    features_shape = (2, 2, 3)

    feature = torch.ones(features_shape)                                                # 3D
    feature_map = torch.stack([feature * (i + 1) for i in range(num_features)], dim=0)  # 4D
    feature_maps = torch.stack([feature_map for i in range(batch_size)], dim=0)         # 5D

    print("input data:\n{} shape is {}".format(feature_maps, feature_maps.shape))

    bn = nn.BatchNorm3d(num_features=num_features, momentum=momentum)

    running_mean, running_var = 0, 1

    for i in range(2):
        outputs = bn(feature_maps)

        print("\niter:{}, running_mean.shape: {}".format(i, bn.running_mean.shape))
        print("iter:{}, running_var.shape: {}".format(i, bn.running_var.shape))

        print("iter:{}, weight.shape: {}".format(i, bn.weight.shape))
        print("iter:{}, bias.shape: {}".format(i, bn.bias.shape))



[Figure: console output of the nn.BatchNorm3d example]
As in the 1d and 2d cases, running_mean, running_var, weight, and bias each have num_features elements; only the shape of each sample's features changes between BatchNorm1d, BatchNorm2d, and BatchNorm3d.
