BatchNormalization
1號坑
-
tf.nn.batch_normalization()
,tf.layers.batch_normalization
和tensorflow.contrib.layers.batch_norm()
,這三個batch normal函數的封裝程度逐漸遞增。這三個函數會自動將update_ops
添加到tf.GraphKeys.UPDATE_OPS
這個collection
中。 -
tf.keras.layers.BatchNormalization
不會自動將update_ops
添加到tf.GraphKeys.UPDATE_OPS
這個collection
中。所以在 TensorFlow 訓練session
中使用tf.keras.layers.BatchNormalization
時,需要手動將keras.BatchNormalization
層的updates
添加到tf.GraphKeys.UPDATE_OPS
這個collection
中。
x = tf.placeholder("float",[None,32,32,3])
bn1 = tf.keras.layers.BatchNormalization()
y = bn1(x, training=True) # 調用後updates屬性纔會有內容。
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, bn1.updates)
- Batch Normalization 中需要計算移動平均值,所以 BN 中有一些
update_ops
,在訓練中需要通過tf.control_dependencies()
來添加對update_ops
的調用:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
self.train_op = tf.train.AdamOptimizer(learning_rate).minimize(self.loss, global_step=self.global_step)
參考
[1] https://blog.csdn.net/u014061630/article/details/85104491
2號坑
使用Batch Normalization的卷積神經網絡,當在驗證階段,將is_training
設置爲False
之後,loss
會爆炸式增長。
1.測試階段代碼:
_, loss, acc = self.sess.run([self.model.train_op, self.model.loss, self.model.acc],
feed_dict={self.model.x: x, self.model.y: y, self.model.is_training: True})
1.驗證階段代碼:
loss, acc = self.sess.run([self.model.loss, self.model.acc],
feed_dict={self.model.x: x, self.model.y: y, self.model.is_training: False})
在解決loss爆炸和測試階段與訓練階段loss和accuracy差異巨大的時候,參考1
效果明顯。
參考
[1] https://stackoverflow.com/questions/47953242/tensorflow-batch-normalization-tf-contrib-layers-batch-norm
[2] https://arxiv.org/pdf/1711.00489.pdf