爲什麼會出現Batch Normalization層

訓練模型時的收斂速度問題

衆所周知,模型訓練需要使用高性能的GPU,還要花費大量的訓練時間。除了數據量大及模型複雜等硬性因素外,數據分佈的不斷變化使得我們必須使用較小的學習率、較好的權重初值和不容易飽和的激活函數(如sigmoid,正負兩邊都會飽和)來訓練模型。這樣速度自然就慢了下來。

下面先簡單示例一下數據分佈的不斷變化爲什麼會帶來這些問題,如圖:
這裏寫圖片描述
我們使用Wx+b=0對小黃和小綠進行分類。由於數據點僅落在第一象限中很小的區域裏,那麼如果我們隨機初始化權重W,需要迭代很多次纔會得到有效的分割,這勢必會帶來求解速率的下降,並且容易遇到局部最優解。

如果我們遇到的僅僅是這麼一個簡單的問題,大家用屁股都能想出來應該怎麼做。均值歸一化呀,把數據挪到原點附近,這樣問題就解決了。但是在深度神經網絡裏,數據分佈的不斷變化至少來自兩方面:a) 每批訓練數據的分佈各不相同(batch梯度下降),那麼網絡就要在每次迭代都去學習適應不同的分佈。b) 網絡中某一層輸入數據的分佈發生改變,後面幾層就會被累積放大下去,這樣模型就需要去不斷適應學習新的數據分佈。論文中把網絡中間層在訓練過程中,參數不斷變化導致的各層輸入分佈的變化 稱爲 Internal Covariate Shift。

怎麼辦呢?這個時候我們可能就會想,如果在每一層輸入的時候,再加個預處理操作那該有多好啊。好,BatchNormalization來了。

怎麼加入歸一化BatchNormalization

  1. 這裏加入的BatchNormalization層可不想我們想象的那麼簡單,它是一個可學習、有參數的網絡層。爲什麼呢?如果我們直接對網絡某一層A的輸出數據做歸一化然後送入網絡下一層B,這樣會影響到本層所學習到的特徵。(比如網絡中某一層學習到的特徵分佈在S型激活函數的某一側,你強制把它歸一化到標準差爲1內。)怎麼辦?論文引入了可學習的參數γ、β(文中把他稱爲變換重構)來保留其學習到的特徵,公式如下:
    這裏寫圖片描述
    接着論文中進一步證明了這一切參數都是可以鏈式求導的(具體見論文),因此γ、β也就可以像權重W那樣不斷迭代優化啦。

  2. 另一個問題是BatchNormalization層加在哪?有兩個落戶地址:a) W x +b 之後,非線性激活函數之前 b) 非線性激活函數後。作者認爲前者效果要更好一點,給出的解釋是:前一個激活層是非線性輸出,其分佈很可能在訓練中變化,而Wx+b更可能“more Gaussian”。(當然也有人簡單嘗試了一下Post-activation batch normalization,效果也還不錯)

其他一切細節自己看論文吧。

BatchNormalization帶來的好處

  1. 可以選擇比較大的初始學習率,讓你的訓練速度飆漲。
  2. 移除或使用較低的drop out、L2正則項參數。
  3. b 可以忽略了,因爲b 的作用其實被 β 代替了。

在tensorflow中應用Batch Normalization

A GENTLE GUIDE TO USING BATCH NORMALIZATION IN TENSORFLOW 這篇文章介紹的特別好,並且末尾的GITHUB代碼裏對全連接網絡的No batch normalization、Standard batch normalization和Post-activation batch normalization三種方法進行了實現和對比。貼一張他的實驗結果圖:
這裏寫圖片描述
這裏我對其Standard batch normalization進行了簡單的修改,貼於此:

import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data


# define our typical fully-connected + batch normalization + nonlinearity set-up
def dense(x, size, scope):
    return tf.contrib.layers.fully_connected(x, size,
                                             activation_fn=None,
                                             scope=scope)


def dense_batch_relu(x, phase, scope):
    with tf.variable_scope(scope):
        h1 = tf.contrib.layers.fully_connected(x, 100,
                                               activation_fn=None,
                                               scope='dense')
        h2 = tf.contrib.layers.batch_norm(h1,
                                          center=True, scale=True,
                                          is_training=phase,
                                          scope='bn')
        return tf.nn.relu(h2, 'relu')


tf.reset_default_graph()
x = tf.placeholder('float32', (None, 784), name='x')
y = tf.placeholder('float32', (None, 10), name='y')
phase = tf.placeholder(tf.bool, name='phase')

h1 = dense_batch_relu(x, phase,'layer1')
h2 = dense_batch_relu(h1, phase, 'layer2')
logits = dense(h2, 10, 'logits')

with tf.name_scope('accuracy'):
    accuracy = tf.reduce_mean(tf.cast(
            tf.equal(tf.argmax(y, 1), tf.argmax(logits, 1)),
            'float32'))

with tf.name_scope('loss'):
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y))


def train(mnist):
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        # Ensures that we execute the update_ops before performing the train_step
        train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    history = []
    iterep = 500
    for i in range(iterep * 30):
        x_train, y_train = mnist.train.next_batch(100)
        sess.run(train_step,
                 feed_dict={'x:0': x_train,
                            'y:0': y_train,
                            'phase:0': 1})
        if (i + 1) % iterep == 0:
            epoch = (i + 1)/iterep
            tr = sess.run([loss, accuracy],
                          feed_dict={'x:0': mnist.train.images,
                                     'y:0': mnist.train.labels,
                                     'phase:0': 1})
            t = sess.run([loss, accuracy],
                         feed_dict={'x:0': mnist.test.images,
                                    'y:0': mnist.test.labels,
                                    'phase:0': 0})
            history += [[epoch] + tr + t]
            print(history[-1])
    return history


def main(argv=None):
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    train(mnist)


if __name__ == '__main__':
    tf.app.run()

參考:
Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
深度學習中 Batch Normalization爲什麼效果好?
Why does batch normalization help?
深度學習(二十九)Batch Normalization 學習筆記
論文筆記-Batch Normalization
《Batch Normalization Accelerating Deep Network Training by Reducing Internal Covariate Shift》閱讀筆記與實現
Implementing Batch Normalization in Tensorflow
A GENTLE GUIDE TO USING BATCH NORMALIZATION IN TENSORFLOW

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章