Tensorflow:estimator訓練

學習流程:Estimator 封裝了對機器學習不同階段的控制,用戶無需不斷的爲新機器學習任務重複編寫訓練、評估、預測的代碼。可以專注於對網絡結構的控制。
數據導入:Estimator 的數據導入也是由 input_fn 獨立定義的。例如,用戶可以非常方便的只通過改變 input_fn 的定義,來使用相同的網絡結構學習不同的數據。
網絡結構:Estimator 的網絡結構是在 model_fn 中獨立定義的,用戶創建的任何網絡結構都可以在 Estimator 的控制下進行機器學習。這可以允許用戶很方便的使用別人定義好的 model_fn。model_fn模型函數必須要有features, mode兩個參數,可自己選擇加入labels(可以把label也放進features中)。最後要返回特定的tf.estimator.EstimatorSpec()。模型有三個階段都共用的正向傳播部分,和由mode值來控制返回不同tf.estimator.EstimatorSpec的三個分支。

 

 

訓練

輸出信息解析

[Tensorflow:模型訓練tensorflow.train]

在訓練或評估中利用Hook打印中間信息

hooks:如果不送值,則訓練過程中不會顯示字典中的數值。

steps:指定了訓練多少次,如果不送值,則訓練到dataset API遍歷完數據集爲止。

max_steps:指定了最大訓練次數。

# 在訓練或評估的循環中,每50次print出一次字典中的數值
tensors_to_log = {"probabilities": "softmax_tensor"}
logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=50)
mnist_classifier.train(input_fn=train_input_fn, hooks=[logging_hook])

 

early stopping

函數原型

tf.contrib.estimator.stop_if_no_increase_hook(
    estimator,
    metric_name,
    max_steps_without_increase,
    eval_dir=None,
    min_steps=0,
    run_every_secs=60,
    run_every_steps=None
)

'stop_if_no_decrease_hook'這個模塊在tf 1.10才加入。hook可以看作一個管理訓練過程的工具,比如說這裏就是設置提前終止的條件,變量loss在100000步以內沒有下降即終止,實際上更廣泛的用法是用在對測試集的f1值上。

參數

metric_name: str類型,比如loss或者accuracy. hook中的參數metric_name='acc'就是tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)中的eval_metric_ops,即tf模塊代碼中通過的for step, metrics in read_eval_metrics(eval_dir).items()得到的。但是訓練好checkpoint後,就不能改,需要刪除之前訓練好的模型,重新訓練。

max_steps_without_increase: int,如果沒有增加的最大長是多少,如果超過了這個最大步長metric還是沒有增加那麼就會停止。

eval_dir:默認是使用estimator.eval_dir目錄,用於存放評估的summary file。

min_steps:訓練的最小步長,如果訓練小於這個步長那麼永遠都不會停止。

run_every_secs和run_every_steps:表示多長時間獲得步長調用一次should_stop_fn。

示例

        metrics = {
            'acc': tf.metrics.accuracy(tf.argmax(labels), tf.argmax(pred_ids)),
            'precision': tf.metrics.precision(tf.argmax(labels), tf.argmax(pred_ids)),
            'precision_': tf_metrics.precision(tf.argmax(labels), tf.argmax(pred_ids), num_labels),
            'recall': tf.metrics.recall(tf.argmax(labels), tf.argmax(pred_ids)),
            'recall_': tf_metrics.recall(tf.argmax(labels), tf.argmax(pred_ids), num_labels),
            'f1_': tf_metrics.f1(tf.argmax(labels), tf.argmax(pred_ids), num_labels),
            'auc': tf.metrics.auc(labels, pred_ids),
        }

        for metric_name, op in metrics.items():
            tf.summary.scalar(metric_name, op[1])

        ''' train and evaluate '''
        if mode == tf.estimator.ModeKeys.EVAL:
            return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)
        elif mode == tf.estimator.ModeKeys.TRAIN:
            train_op = tf.train.AdamOptimizer().minimize(loss=loss,
                                                         global_step=tf.train.get_or_create_global_step())

...

        hook = tf.contrib.estimator.stop_if_no_increase_hook(estimator, 'f1', max_steps_without_increase=1000,
                                                             min_steps=8000, run_every_secs=120) 
        train_spec = tf.estimator.TrainSpec(input_fn=train_inpf, hooks=[hook])

[簡書tf.estimate]

from: -柚子皮-

ref:

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章