一、Batch Normalization概念
1.1 Batch Normalization概念
Batch Normalization:批標準化
- 批:一批數據,通常爲mini-batch
- 標準化: 0均值, 1方差
優點:
- 可以用更大學習率,加速模型收斂
- 可以不用精心設計權值初始化
- 可以不用dropout或較小的dropout
- 可以不用L2或者較小的weight decay
- 可以不用LRN(local response normalization)
《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》
計算方式:
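輸入:一個mini-batch數據 $\mathcal{B}=\{x_1,\dots,x_m\}$(m個),兩個待學習的參數 $\gamma$、$\beta$
輸出:$\{y_i=\mathrm{BN}_{\gamma,\beta}(x_i)\}$
- 求取mini-batch數據的均值和方差:$\mu_{\mathcal{B}}=\frac{1}{m}\sum_{i=1}^{m}x_i$,$\sigma_{\mathcal{B}}^2=\frac{1}{m}\sum_{i=1}^{m}(x_i-\mu_{\mathcal{B}})^2$
- 對mini-batch中的每個數據標準化:$\hat{x}_i=\frac{x_i-\mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^2+\epsilon}}$,其中 $\epsilon$ 是修正項,防止分母爲0
- 對上一步數據進行affine transform:$y_i=\gamma\hat{x}_i+\beta$,可理解爲縮放和平移,增強Capacity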
1.2 Internal Covariate Shift (ICS)
ICS:可以簡單理解爲數據尺度或分佈的變化
由上圖中的D(H1)=n*D(x)*D(W)=1可知,第一個隱藏層輸出的方差等於神經元個數n、上一層輸入的方差以及二者之間權重的方差三者的連乘,所以如果數據的方差發生微小變化,那麼隨着網絡的加深,這個變化會逐層累積、越來越明顯,從而導致梯度消失或梯度爆炸
所以數據尺度或分佈發生變化,則會導致模型難以訓練
而Batch Normalization就是爲了解決這個問題而推出來的
1.3 Batch Normalization應用
1.3.1 使用BN,可以不用權值初始化
# -*- coding: utf-8 -*-
import torch
import numpy as np
import torch.nn as nn
from tools.common_tools import set_seed
set_seed(1) # 設置隨機種子
class MLP(nn.Module):
    """Deep stack of bias-free linear layers, each followed by BatchNorm1d and ReLU.

    Args:
        neural_num: Width of every layer (both input and output feature count).
        layers: Number of Linear -> BatchNorm1d -> ReLU stages.
    """

    def __init__(self, neural_num, layers=100):
        super().__init__()
        linear_stack = [nn.Linear(neural_num, neural_num, bias=False) for _ in range(layers)]
        norm_stack = [nn.BatchNorm1d(neural_num) for _ in range(layers)]
        self.linears = nn.ModuleList(linear_stack)
        self.bns = nn.ModuleList(norm_stack)
        self.neural_num = neural_num

    def forward(self, x):
        """Run ``x`` through every stage, printing the activation std per layer.

        Stops early, after reporting the layer index, once the std becomes NaN.

        Returns:
            The activations of the last stage that was evaluated.
        """
        for depth, (linear, bn) in enumerate(zip(self.linears, self.bns)):
            # BN is applied before the activation function.
            x = torch.relu(bn(linear(x)))

            spread = x.std()
            if torch.isnan(spread):
                print("output is nan in {} layers".format(depth))
                break
            print("layers:{}, std:{}".format(depth, spread.item()))
        return x

    def initialize(self):
        """Re-initialize the weight of every nn.Linear sub-module with Kaiming-normal."""
        linear_layers = (m for m in self.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            # Alternative (plain normal, mean=0, std=1):
            #   nn.init.normal_(layer.weight.data, std=1)
            nn.init.kaiming_normal_(layer.weight.data)
# Experiment configuration: a 100-layer, 256-wide network fed with a batch of 16 samples.
neural_nums = 256
layer_nums = 100
batch_size = 16

# The explicit weight initialization is left disabled on purpose;
# uncomment the call below to apply MLP.initialize().
net = MLP(neural_num=neural_nums, layers=layer_nums)
# net.initialize()

# Standard-normal input (mean=0, std=1).
inputs = torch.randn(batch_size, neural_nums)

output = net(inputs)
print(output)
可以從上圖看到,當使用了BN,不使用權值初始化,每層的標準差依然保持的很好
1.3.2 BN應用二分類模型
# -*- coding:utf-8 -*-
"""
@file name : bn_application.py
# @author : TingsongYu https://github.com/TingsongYu
@date : 2019-11-01
@brief : nn.BatchNorm使用
"""
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
from tools.common_tools import set_seed
class LeNet_bn(nn.Module):
    """LeNet-style CNN with a batch-norm layer after conv1, conv2 and fc1.

    ``fc1`` takes 16 * 5 * 5 features, which is what two (5x5 conv -> 2x2 max-pool)
    stages produce from a 32x32 input.

    Args:
        classes: Number of output logits produced by the last linear layer.
    """

    def __init__(self, classes):
        super(LeNet_bn, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.bn1 = nn.BatchNorm2d(num_features=6)

        self.conv2 = nn.Conv2d(6, 16, 5)
        self.bn2 = nn.BatchNorm2d(num_features=16)

        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.bn3 = nn.BatchNorm1d(num_features=120)

        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, classes)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        # Stage 1: conv -> BN -> ReLU -> 2x2 max-pool
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)
        out = F.max_pool2d(out, 2)

        # Stage 2: conv -> BN -> ReLU -> 2x2 max-pool
        out = self.conv2(out)
        out = self.bn2(out)
        out = F.relu(out)
        out = F.max_pool2d(out, 2)

        # Flatten everything except the batch dimension for the linear layers.
        out = out.view(out.size(0), -1)

        out = self.fc1(out)
        out = self.bn3(out)
        out = F.relu(out)

        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

    def initialize_weights(self):
        """Re-initialize all conv, batch-norm and linear sub-modules in place.

        - Conv2d: Xavier-normal weights, zero bias.
        - BatchNorm1d / BatchNorm2d: weight (gamma) = 1, bias (beta) = 0.
        - Linear: weights drawn from N(0, 1), zero bias.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                # BatchNorm1d is included so that bn3 is reset like bn1/bn2.
                # weight/bias are None when the layer was built with affine=False.
                if m.weight is not None:
                    m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight.data, 0, 1)
                if m.bias is not None:
                    m.bias.data.zero_()
# Training script: trains a 2-class classifier on the RMB split, logs loss and
# accuracy to TensorBoard, and finally plots the train/valid loss curves.
set_seed(1)  # fix the random seed
rmb_label = {"1": 0, "100": 1}

# Hyper-parameters
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10  # print training statistics every `log_interval` iterations
val_interval = 1   # run validation every `val_interval` epochs

# ============================ step 1/5: data ============================
split_dir = os.path.join("..", "..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomGrayscale(p=0.8),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Build the dataset instances
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# Build the DataLoaders
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5: model ============================
# Switch between the plain LeNet and the BN variant / explicit initialization
# by toggling the commented lines below.
# net = LeNet_bn(classes=2)
net = LeNet(classes=2)
# net.initialize_weights()

# ============================ step 3/5: loss function ============================
criterion = nn.CrossEntropyLoss()

# ============================ step 4/5: optimizer ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)  # LR decay policy

# ============================ step 5/5: training ============================
train_curve = list()
valid_curve = list()

iter_count = 0

# Build the SummaryWriter
writer = SummaryWriter(comment='test_your_comment', filename_suffix="_test_your_filename_suffix")

for epoch in range(MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):

        iter_count += 1

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # classification statistics (accumulated over the whole epoch)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # print training information
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

        # record scalars in the event file
        writer.add_scalars("Loss", {"Train": loss.item()}, iter_count)
        writer.add_scalars("Accuracy", {"Train": correct / total}, iter_count)

    scheduler.step()  # update the learning rate

    # validate the model
    if (epoch+1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()

                loss_val += loss.item()

            # Record the epoch-mean validation loss (not just the last batch's loss),
            # and report accuracy from the validation counters rather than the
            # training ones.
            loss_val_mean = loss_val / len(valid_loader)
            acc_val = correct_val / total_val
            valid_curve.append(loss_val_mean)
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val_mean, acc_val))

            # record scalars in the event file
            writer.add_scalars("Loss", {"Valid": loss_val_mean}, iter_count)
            writer.add_scalars("Accuracy", {"Valid": acc_val}, iter_count)

# Flush pending events and release the event file before the blocking plot window opens.
writer.close()

train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
# valid_curve holds one loss per validated epoch; convert those points to iteration indices.
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()
從上圖可以看到,在訓練時,因爲沒有經過權值初始化,在80到100個iteration時,loss會發生激增
在使用了權值初始化後,可以看到,在第20到40個iteration時會出現激增,而往後,Loss趨於減小,並逐漸到了一個很好的結果
可以看到在加入了BN層後,雖然loss也會出現激增,但是幅度小,而且始終保持在一個良好尺度範圍內
二、PyTorch的Batch Normalization 1d/2d/3d實現
2.1 基類——_BatchNorm
__init__(self,
num_features,
eps=1e-5,
momentum=0.1,
affine=True,
track_running_stats=True)
_BatchNorm
- nn.BatchNorm1d
- nn.BatchNorm2d
- nn.BatchNorm3d
參數:
- num_features:一個樣本特徵數量(最重要)
- eps:分母修正項
- momentum:指數加權平均估計當前mean/var
- affine:是否需要affine transform
- track_running_stats:是否追蹤記錄running_mean/running_var;爲True時,訓練階段更新running統計量、測試階段用其做標準化;爲False時,訓練和測試都使用當前batch的統計值
2.2 nn.BatchNorm1d/2d/3d
- nn.BatchNorm1d
- nn.BatchNorm2d
- nn.BatchNorm3d
主要屬性:
- running_mean:均值
- running_var:方差
- weight: affine transform中的gamma
- bias: affine transform中的beta
訓練時:用當前mini-batch的均值和方差做標準化,同時running_mean和running_var採用指數加權平均進行更新
running_mean = (1 - momentum) * pre_running_mean + momentum * mean_t
running_var = (1 - momentum) * pre_running_var + momentum * var_t
測試時:使用訓練階段累積得到的running_mean和running_var(即當前統計值)做標準化,不再更新
代碼示例:
# -*- coding: utf-8 -*-
import torch
import numpy as np
import torch.nn as nn
from tools.common_tools import set_seed
set_seed(1) # 設置隨機種子
# ======================================== nn.BatchNorm1d
# Demo: feed the same batch through BatchNorm1d twice and compare the layer's
# running statistics with a hand-computed exponential moving average.
flag = 1  # set to 0 to skip this demo
if flag:

    batch_size = 3
    num_features = 5
    momentum = 0.3

    features_shape = (1)  # a plain int (not a 1-tuple); torch.ones(1) has shape [1]

    # Feature k (0-based) is filled with the constant k + 1.
    feature_map = torch.ones(features_shape)  # 1D
    scaled_maps = [feature_map * level for level in range(1, num_features + 1)]
    feature_maps = torch.stack(scaled_maps, dim=0)  # 2D
    # Every sample in the batch is the same stack of features.
    feature_maps_bs = torch.stack([feature_maps] * batch_size, dim=0)  # 3D
    print("input data:\n{} shape is {}".format(feature_maps_bs, feature_maps_bs.shape))

    bn = nn.BatchNorm1d(num_features=num_features, momentum=momentum)

    # Hand-tracked running statistics, starting from mean 0 and variance 1.
    running_mean, running_var = 0, 1

    for i in range(2):
        outputs = bn(feature_maps_bs)

        print("\niteration:{}, running mean: {} ".format(i, bn.running_mean))
        print("iteration:{}, running var:{} ".format(i, bn.running_var))

        # Batch statistics of the second feature: every value is 2, so the variance is 0.
        mean_t, var_t = 2, 0

        running_mean = momentum * mean_t + (1 - momentum) * running_mean
        running_var = momentum * var_t + (1 - momentum) * running_var

        print("iteration:{}, 第二個特徵的running mean: {} ".format(i, running_mean))
        print("iteration:{}, 第二個特徵的running var:{}".format(i, running_var))
這裏的running_mean和running_var並不是只由當前mini-batch計算得到的,而是通過指數加權平均綜合了之前各個mini-batch的信息來估計均值和方差;訓練時標準化使用的是當前mini-batch的統計值,running統計量則在測試(eval)階段用來對數據進行標準化
# -*- coding: utf-8 -*-
import torch
import numpy as np
import torch.nn as nn
from tools.common_tools import set_seed
set_seed(1) # 設置隨機種子
# ======================================== nn.BatchNorm2d
# Demo: run a 4D batch through BatchNorm2d twice and print the shapes of the
# layer's running statistics and affine parameters.
flag = 1  # set to 0 to skip this demo
if flag:

    batch_size = 3
    num_features = 6
    momentum = 0.3

    features_shape = (2, 2)

    # Feature map k (0-based) is a 2x2 map filled with the constant k + 1.
    feature_map = torch.ones(features_shape)  # 2D
    scaled_maps = [feature_map * level for level in range(1, num_features + 1)]
    feature_maps = torch.stack(scaled_maps, dim=0)  # 3D
    # Every sample in the batch is the same stack of feature maps.
    feature_maps_bs = torch.stack([feature_maps] * batch_size, dim=0)  # 4D

    print("input data:\n{} shape is {}".format(feature_maps_bs, feature_maps_bs.shape))

    bn = nn.BatchNorm2d(num_features=num_features, momentum=momentum)

    running_mean, running_var = 0, 1  # not used further in this demo

    for i in range(2):
        outputs = bn(feature_maps_bs)

        shape_report = (
            ("\niter:{}, running_mean.shape: {}", bn.running_mean.shape),
            ("iter:{}, running_var.shape: {}", bn.running_var.shape),
            ("iter:{}, weight.shape: {}", bn.weight.shape),
            ("iter:{}, bias.shape: {}", bn.bias.shape),
        )
        for template, shape in shape_report:
            print(template.format(i, shape))
由上圖可知,BN是按特徵(通道)逐個計算的:running_mean、running_var、weight和bias的形狀都等於特徵的數量num_features
# -*- coding: utf-8 -*-
import torch
import numpy as np
import torch.nn as nn
from tools.common_tools import set_seed
set_seed(1) # 設置隨機種子
# ======================================== nn.BatchNorm3d
# Demo: run a 5D batch through BatchNorm3d twice and print the shapes of the
# layer's running statistics and affine parameters.
flag = 1  # set to 0 to skip this demo
if flag:

    batch_size = 3
    num_features = 4
    momentum = 0.3

    features_shape = (2, 2, 3)

    # Feature k (0-based) is a 2x2x3 volume filled with the constant k + 1.
    feature = torch.ones(features_shape)  # 3D
    scaled_volumes = [feature * level for level in range(1, num_features + 1)]
    feature_map = torch.stack(scaled_volumes, dim=0)  # 4D
    # Every sample in the batch is the same stack of volumes.
    feature_maps = torch.stack([feature_map] * batch_size, dim=0)  # 5D

    print("input data:\n{} shape is {}".format(feature_maps, feature_maps.shape))

    bn = nn.BatchNorm3d(num_features=num_features, momentum=momentum)

    running_mean, running_var = 0, 1  # not used further in this demo

    for i in range(2):
        outputs = bn(feature_maps)

        shape_report = (
            ("\niter:{}, running_mean.shape: {}", bn.running_mean.shape),
            ("iter:{}, running_var.shape: {}", bn.running_var.shape),
            ("iter:{}, weight.shape: {}", bn.weight.shape),
            ("iter:{}, bias.shape: {}", bn.bias.shape),
        )
        for template, shape in shape_report:
            print(template.format(i, shape))