BN 存在哪些問題:
1. BN 依賴大batch size, 當 batch size 太小時, batch statistics 變得不準確; 而顯存限制了batch size變大,尤其在檢測、分割等比較佔用顯存的模型上。 batch size上又是一個工程問題, 畢竟去年的coco,Face++主要贏在大batch上,這是最重要的motivation。
2. BN要求batch分佈比較理想, 因爲BN是沿着[N, H, W]進行統計,在複雜的任務中batch內的樣本未符合i.i.d.,比如video裏的連續幀, 比如detection box/mask head 裏, 一個batch裏的512個proposals 是高度關聯甚至是重複的。
3. Train/Test不一致, 訓練時通過指數滑動平均(EMA)計算出來的 running_mean, running_vars到最後雖然也是能夠收斂的,但是測試集和訓練集數據分佈往往並不完全一致,會造成模型在training/testing的性能差異。
GN做了什麼?
1. GN不依賴batch size, group 是指對 channels進行 grouping,然後沿着[H, W, C/G] 進行統計,計算mean and vars, 擺脫了對N的依賴。由於是per-N 進行統計的, 那麼就不要求batch內的N個樣本符合i.i.d.
2. GN在 testing time 也會根據不同的輸入計算不同的mean 和 vars, 並不像BN那樣使用training時的統計值,不存在Train/Test不一致的問題。
GN真的比BN好用嗎?
在大batch上, BN依然很有優勢,在小batch上,論文 declare 具有優勢,實際效果還要case by case去驗證。
GN 和 BN的實現(Tensorflow 版)
def GroupNorm(x, group, gamma_initializer=tf.constant_initializer(1.)):
"""
https://arxiv.org/abs/1803.08494
"""
shape = x.get_shape().as_list()
ndims = len(shape)
assert ndims == 4, shape
chan = shape[1]
assert chan % group == 0, chan
group_size = chan // group
orig_shape = tf.shape(x)
h, w = orig_shape[2], orig_shape[3]
x = tf.reshape(x, tf.stack([-1, group, group_size, h, w]))
mean, var = tf.nn.moments(x, [2, 3, 4], keep_dims=True)
new_shape = [1, group, group_size, 1, 1]
beta = tf.get_variable('beta', [chan], initializer=tf.constant_initializer())
beta = tf.reshape(beta, new_shape)
gamma = tf.get_variable('gamma', [chan], initializer=gamma_initializer)
gamma = tf.reshape(gamma, new_shape)
out = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-5, name='output')
return tf.reshape(out, orig_shape, name='output')
def BatchNorm(x, n_out, phase_train, scope='bn'):
"""
Batch normalization on convolutional maps.
Args:
x: Tensor, 4D BHWD input maps
n_out: integer, depth of input maps
phase_train: boolean tf.Varialbe, true indicates training phase
scope: string, variable scope
Return:
normed: batch-normalized maps
"""
with tf.variable_scope(scope):
beta = tf.Variable(tf.constant(0.0, shape=[n_out]),
name='beta', trainable=True)
gamma = tf.Variable(tf.constant(1.0, shape=[n_out]),
name='gamma', trainable=True)
batch_mean, batch_var = tf.nn.moments(x, [0,1,2], name='moments')
ema = tf.train.ExponentialMovingAverage(decay=0.5)
def mean_var_with_update():
ema_apply_op = ema.apply([batch_mean, batch_var])
with tf.control_dependencies([ema_apply_op]):
return tf.identity(batch_mean), tf.identity(batch_var)
mean, var = tf.cond(phase_train,
mean_var_with_update,
lambda: (ema.average(batch_mean), ema.average(batch_var)))
normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-3)
return normed
小batch size下BN tricks
1. 增大BN統計的範圍
f = f.reshape([N, H, W * G, C//G])
f = BN(f) # standard BN
f = f.reshape([N, H, W, C])
BN爲每個channel單獨算一個mean和var;這種BN trick的思路是爲每個channel group計算一個mean和var,和GN的motivation有點像。在batch size較小的時候相當於強行增大了BN統計的範圍(從N*H*W增大到了N*H*W*G),使BN統計更爲穩定, 可以用於解決小batch size的副作用。
2. Synchronized BN
常規的BN是在每一個GPU上單獨計算mean和var, 但Synchronized BN 跨卡計算mean和var, 減少batch size的副作用。