estimator有個問題就是驗證時是從文件中載入模型的,這樣存在問題是無法保證從保存到載入期間的完全正確性。對於這種問題,我們一般採用少量數據,然後在訓練集上進行驗證。確認預測數據是否一致。主要是使用cond的控制最後幾輪不進行訓練,並且把數據打印出來。
def train_func():
# 構建訓練節點
train_op = create_optimizer(
total_loss, lr, optimizer_params, 1., variables_to_train, use_fp16=FLAGS.use_fp16)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
update_ops.append(train_op)
update_op = tf.group(*update_ops)
with tf.control_dependencies([update_op]):
train_tensor = tf.identity(total_loss)
return train_tensor
train_tensor = tf.cond(global_step_tensor>5, lambda : tf.Print(total_loss,
[tf.shape(per_example_loss), per_example_loss,
tf.reduce_mean(per_example_loss),
tf.reduce_mean(tf.to_float(per_example_loss<FLAGS.margin))], summarize=32), lambda : train_func() )
tf.cond and tf.case execute all branches
tensorflow: Initializer for variable… is from inside a control-flow construct, a loop or conditional