tensorflow estimator 訓練完的模型與驗證時載入的模型是否一致

estimator有個問題就是驗證時是從文件中載入模型的,這樣存在問題是無法保證從保存到載入期間的完全正確性。對於這種問題,我們一般採用少量數據,然後在訓練集上進行驗證。確認預測數據是否一致。主要是使用cond的控制最後幾輪不進行訓練,並且把數據打印出來。

            def train_func():

                # 構建訓練節點
                train_op = create_optimizer(
                    total_loss, lr, optimizer_params, 1., variables_to_train, use_fp16=FLAGS.use_fp16)

                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                update_ops.append(train_op)
                update_op = tf.group(*update_ops)
                with tf.control_dependencies([update_op]):
                    train_tensor = tf.identity(total_loss)
                return train_tensor


            train_tensor = tf.cond(global_step_tensor>5, lambda : tf.Print(total_loss,
                            [tf.shape(per_example_loss), per_example_loss,
                             tf.reduce_mean(per_example_loss),
                             tf.reduce_mean(tf.to_float(per_example_loss<FLAGS.margin))], summarize=32), lambda : train_func() )

tf.cond and tf.case execute all branches
tensorflow: Initializer for variable… is from inside a control-flow construct, a loop or conditional

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章