學習流程:Estimator 封裝了對機器學習不同階段的控制,用戶無需不斷的爲新機器學習任務重複編寫訓練、評估、預測的代碼。可以專注於對網絡結構的控制。
數據導入:Estimator 的數據導入也是由 input_fn 獨立定義的。例如,用戶可以非常方便的只通過改變 input_fn 的定義,來使用相同的網絡結構學習不同的數據。
網絡結構:Estimator 的網絡結構是在 model_fn 中獨立定義的,用戶創建的任何網絡結構都可以在 Estimator 的控制下進行機器學習。這可以允許用戶很方便的使用別人定義好的 model_fn。model_fn模型函數必須要有features, mode兩個參數,可自己選擇加入labels(可以把label也放進features中)。最後要返回特定的tf.estimator.EstimatorSpec()。模型有三個階段都共用的正向傳播部分,和由mode值來控制返回不同tf.estimator.EstimatorSpec的三個分支。
訓練
輸出信息解析
[Tensorflow:模型訓練tensorflow.train]
在訓練或評估中利用Hook打印中間信息
hooks:如果不送值,則訓練過程中不會顯示字典中的數值。
steps:指定了訓練多少次,如果不送值,則訓練到dataset API遍歷完數據集爲止。
max_steps:指定了最大訓練次數。
# 在訓練或評估的循環中,每50次print出一次字典中的數值
tensors_to_log = {"probabilities": "softmax_tensor"}
logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=50)
mnist_classifier.train(input_fn=train_input_fn, hooks=[logging_hook])
early stopping
函數原型
tf.contrib.estimator.stop_if_no_increase_hook(
estimator,
metric_name,
max_steps_without_increase,
eval_dir=None,
min_steps=0,
run_every_secs=60,
run_every_steps=None
)
'stop_if_no_decrease_hook'這個模塊在tf 1.10才加入。hook可以看作一個管理訓練過程的工具,比如說這裏就是設置提前終止的條件,變量loss在100000步以內沒有下降即終止,實際上更廣泛的用法是用在對測試集的f1值上。
參數
metric_name: str類型,比如loss或者accuracy. hook中的參數metric_name='acc'就是tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)中的eval_metric_ops,即tf模塊代碼中通過的for step, metrics in read_eval_metrics(eval_dir).items()得到的。但是訓練好checkpoint後,就不能改,需要刪除之前訓練好的模型,重新訓練。
max_steps_without_increase: int,如果沒有增加的最大長是多少,如果超過了這個最大步長metric還是沒有增加那麼就會停止。
eval_dir:默認是使用estimator.eval_dir目錄,用於存放評估的summary file。
min_steps:訓練的最小步長,如果訓練小於這個步長那麼永遠都不會停止。
run_every_secs和run_every_steps:表示多長時間獲得步長調用一次should_stop_fn。
示例
metrics = {
'acc': tf.metrics.accuracy(tf.argmax(labels), tf.argmax(pred_ids)),
'precision': tf.metrics.precision(tf.argmax(labels), tf.argmax(pred_ids)),
'precision_': tf_metrics.precision(tf.argmax(labels), tf.argmax(pred_ids), num_labels),
'recall': tf.metrics.recall(tf.argmax(labels), tf.argmax(pred_ids)),
'recall_': tf_metrics.recall(tf.argmax(labels), tf.argmax(pred_ids), num_labels),
'f1_': tf_metrics.f1(tf.argmax(labels), tf.argmax(pred_ids), num_labels),
'auc': tf.metrics.auc(labels, pred_ids),
}
for metric_name, op in metrics.items():
tf.summary.scalar(metric_name, op[1])
''' train and evaluate '''
if mode == tf.estimator.ModeKeys.EVAL:
return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)
elif mode == tf.estimator.ModeKeys.TRAIN:
train_op = tf.train.AdamOptimizer().minimize(loss=loss,
global_step=tf.train.get_or_create_global_step())
...
hook = tf.contrib.estimator.stop_if_no_increase_hook(estimator, 'f1', max_steps_without_increase=1000,
min_steps=8000, run_every_secs=120)
train_spec = tf.estimator.TrainSpec(input_fn=train_inpf, hooks=[hook])
[簡書tf.estimate]
from: -柚子皮-
ref: