tf.train.MonitoredSession 簡介

在run過程中的集成一些操作,比如輸出log,保存,summary 等


基類一般用在infer階段,訓練階段使用它的子類
tf.train.MonitoredTrainingSession

1 MonitoredTrainingSession

1.1 構造函數

MonitoredTrainingSession(
    master='',
    is_chief=True,
    checkpoint_dir=None,
    scaffold=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=600,
    save_summaries_steps=USE_DEFAULT,
    save_summaries_secs=USE_DEFAULT,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200
)

官方例子

saver_hook = CheckpointSaverHook(...)
summary_hook = SummarySaverHook(...)
with MonitoredSession(session_creator=ChiefSessionCreator(...),
                      hooks=[saver_hook, summary_hook]) as sess:
  while not sess.should_stop():
    sess.run(train_op)

首先,當MonitoredSession初始化的時候,會按順序執行下面操作:

  • 調用hook的begin()函數,我們一般在這裏進行一些hook內的初始化。比如在上面貓狗大戰中的_LoggerHook裏面的_step屬性,就是用來記錄執行步驟的,但是該參數只在本類中起作用。
  • 通過調用scaffold.finalize()初始化計算圖
    創建會話
  • 通過初始化Scaffold提供的操作(op)來初始化模型
  • 如果checkpoint存在的話,restore模型的參數
  • launches queue runners
  • 調用hook.after_create_session()

然後,當run()函數運行的時候,按順序執行下列操作:

  • 調用hook.before_run()
  • 調用TensorFlow的 session.run()
  • 調用hook.after_run()
  • 返回用戶需要的session.run()的結果
  • 如果發生了AbortedError或者UnavailableError,則在再次執行run()之前恢復或者重新初始化會話

最後,當調用close()退出時,按順序執行下列操作:

  • 調用hook.end()
  • 關閉隊列和會話
  • 阻止OutOfRange錯誤

1.2 Hook

所以這些鉤子函數就是重點關注的對象

.1 LoggingTensorHook

tf.train.LoggingTensorHook 官方說明

Prints the given tensors every N local steps, every N seconds, or at end.

__init__(
    tensors,
    every_n_iter=None,
    every_n_secs=None,
    formatter=None
)
  • tensors: dict that maps string-valued tags to tensors/tensor names, or iterable of tensors/tensor names.

用法舉例

# Set up logging for predictions
  tensors_to_log = {"probabilities": "softmax_tensor"}
  logging_hook = tf.train.LoggingTensorHook(
      tensors=tensors_to_log, every_n_iter=50)

.2 SummarySaverHook

tf.train.SummarySaverHook

Saves summaries every N steps

__init__(
    save_steps=None,
    save_secs=None,
    output_dir=None,
    summary_writer=None,
    scaffold=None,
    summary_op=None
)

output_dir 填 路徑
summary_op 填 tf.summary.merge_all

.3 CheckpointSaverHook

tf.train.CheckpointSaverHook
MonitoredTrainingSession 只有 save_checkpoint_secs, 沒有按step保存的選項
* Saves checkpoints every N steps or seconds

__init__(
    checkpoint_dir,
    save_secs=None,
    save_steps=None,
    saver=None,
    checkpoint_basename='model.ckpt',
    scaffold=None,
    listeners=None
)

必填 saver, save_secs 或者 save_steps

.4 NanTensorHook

tf.train.NanTensorHook
感覺是用來調試的,加到訓練過程中可能會拖慢train

  • Monitors the loss tensor and stops training if loss is NaN.
    Can either fail with exception or just stop training.
__init__(
    loss_tensor,
    fail_on_nan_loss=True
)

.5 FeedFnHook

tf.train.FeedFnHook
看着像用來產生 feed_dict

Runs feed_fn and sets the feed_dict accordingly

__init__(feed_fn)

.6 GlobalStepWaiterHook

tf.train.GlobalStepWaiterHook
分佈式用

.7 ProfilerHook

tf.train.ProfilerHook

This hook delays execution until global step reaches to wait_until_step. It is used to gradually start workers in distributed settings. One example usage would be setting wait_until_step=int(K*log(task_id+1)) assuming that task_id=0 is the chief

reference

tf.train.MonitoredSession
https://www.tensorflow.org/versions/master/api_docs/python/tf/train/MonitoredSession
resnet_main.py
https://github.com/tensorflow/models/blob/master/research/resnet/resnet_main.py
tf.train.MonitoredTrainingSession
https://www.tensorflow.org/versions/master/api_docs/python/tf/train/MonitoredTrainingSession
使用自己的數據集進行一次完整的TensorFlow訓練
https://zhuanlan.zhihu.com/p/32490882
tf.train.LoggingTensorHook
https://www.tensorflow.org/api_docs/python/tf/train/LoggingTensorHook
tf.train.SummarySaverHook
https://www.tensorflow.org/versions/master/api_docs/python/tf/train/SummarySaverHook
tf.train.CheckpointSaverHook
https://www.tensorflow.org/versions/master/api_docs/python/tf/train/CheckpointSaverHook
tf.train.NanTensorHook
https://www.tensorflow.org/versions/master/api_docs/python/tf/train/NanTensorHook#__init__

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章