TensorFlow踩坑指南

BatchNormalization

1號坑

  1. tf.nn.batch_normalization(),tf.layers.batch_normalizationtensorflow.contrib.layers.batch_norm(),這三個batch normal函數的封裝程度逐漸遞增。這三個函數會自動將 update_ops 添加到tf.GraphKeys.UPDATE_OPS這個collection中。

  2. tf.keras.layers.BatchNormalization 不會自動將 update_ops 添加到tf.GraphKeys.UPDATE_OPS 這個collection中。所以在 TensorFlow 訓練 session 中使用 tf.keras.layers.BatchNormalization 時,需要手動將 keras.BatchNormalization 層的 updates 添加到tf.GraphKeys.UPDATE_OPS 這個 collection 中。

x = tf.placeholder("float",[None,32,32,3])
bn1 = tf.keras.layers.BatchNormalization()
y = bn1(x, training=True) # 調用後updates屬性纔會有內容。
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, bn1.updates)
  1. Batch Normalization 中需要計算移動平均值,所以 BN 中有一些 update_ops,在訓練中需要通過tf.control_dependencies()來添加對update_ops 的調用:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        self.train_op = tf.train.AdamOptimizer(learning_rate).minimize(self.loss, global_step=self.global_step)

參考

[1] https://blog.csdn.net/u014061630/article/details/85104491

2號坑

使用Batch Normalization的卷積神經網絡,當在驗證階段,將is_training設置爲False之後,loss會爆炸式增長。
1.測試階段代碼:

_, loss, acc = self.sess.run([self.model.train_op, self.model.loss, self.model.acc],
                              feed_dict={self.model.x: x, self.model.y: y, self.model.is_training: True})

1.驗證階段代碼:

loss, acc = self.sess.run([self.model.loss, self.model.acc],
                           feed_dict={self.model.x: x, self.model.y: y, self.model.is_training: False})

在解決loss爆炸和測試階段與訓練階段loss和accuracy差異巨大的時候,參考1效果明顯。
在這裏插入圖片描述

參考

[1] https://stackoverflow.com/questions/47953242/tensorflow-batch-normalization-tf-contrib-layers-batch-norm
[2] https://arxiv.org/pdf/1711.00489.pdf

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章