訓練代碼中使用了tf.train.SessionRunHook(),tf.train.MonitoredTrainingSession(),查看官方API後可以知道
一. tf.train.MonitoredTrainingSession()
首先,tf.train.MonitorSession()從單詞的字面意思理解是用於監控訓練的回話,返回值是tf.train.MonitorSession()類的一個實例Object, tf.train.MonitorSession()會在下面講。
tf.train.MonitoredTrainingSession(
master='',
is_chief=True,
checkpoint_dir=None,
scaffold=None,
hooks=None,
chief_only_hooks=None,
save_checkpoint_secs=600,
save_summaries_steps=USE_DEFAULT,
save_summaries_secs=USE_DEFAULT,
config=None,
stop_grace_period_secs=120,
log_step_count_steps=100
)
Args:
- is_chief:用於分佈式系統中,用於判斷該系統是否是chief,如果爲True,它將負責初始化並恢復底層TensorFlow會話。如果爲False,它將等待chief初始化或恢復TensorFlow會話。
- checkpoint_dir:一個字符串。指定一個用於恢復變量的checkpoint文件路徑。
- scaffold:用於收集或建立支持性操作的腳手架。如果未指定,則會創建默認一個默認的scaffold。它用於完成圖表
- hooks:SessionRunHook對象的可選列表。可自己定義SessionRunHook對象,也可用已經預定義好的SessionRunHook對象,如:tf.train.StopAtStepHook()設置停止訓練的條件;tf.train.NanTensorHook(loss):如果loss的值爲Nan則停止訓練;
- chief_only_hooks:SessionRunHook對象列表。如果is_chief== True,則激活這些掛鉤,否則忽略。
- save_checkpoint_secs:用默認的checkpoint saver保存checkpoint的頻率(以秒爲單位)。如果save_checkpoint_secs設置爲None,不保存checkpoint。
- save_summaries_steps:使用默認summaries saver將摘要寫入磁盤的頻率(以全局步數表示)。如果save_summaries_steps和save_summaries_secs都設置爲None,則不使用默認的summaries saver保存summaries。默認爲100
- save_summaries_secs:使用默認summaries saver將摘要寫入磁盤的頻率(以秒爲單位)。如果save_summaries_steps和save_summaries_secs都設置爲None,則不使用默認的摘要保存。默認未啓用。
- config:用於配置會話的tf.ConfigProtoproto的實例。它是tf.Session的構造函數的config參數。
- stop_grace_period_secs:調用close()後線程停止的秒數。
- log_step_count_steps:記錄全局步/秒的全局步數的頻率
Returns:
一個MonitoredSession() 實例
二. tf.train.MonitoredSession()類
官方文檔給的定義是:
Session-like object that handles initialization, recovery and hooks
是一個處理初始化,模型恢復,和處理Hooks的類似與Session的類。
Args:
- session_creator:制定用於創建回話的ChiefSessionCreator
- hooks:tf.train.SessionRunHook()實例的列表
Returns: 一個MonitoredSession 實例
Example usage:
saver_hook = tf.train.CheckpointSaverHook(...)
summary_hook = tf.train.SummarySaverHook(...)
with tf.train.MonitoredSession(session_creator=ChiefSessionCreator(...),
hooks=[saver_hook, summary_hook]) as sess:
while not sess.should_stop():
sess.run(train_op)
初始化:在創建一個MonitoredSession時,會按順序執行以下操作:
- 調用[Hooks]列表中每一個Hook的begin()函數
- 通過scaffold.finalize()完成圖graph的定義
- 創建會話
- 用Scaffold提供的初始化操作(op)來初始化模型
- 如果給定checkpoint_dir中存在checkpoint文件,則用checkpoint恢復變量
- 啓動隊列線程
- 調用hook.after_create_session()
Run:當調用run()函數時,按順序執行以下操作
- 調用hook.before_run()
- 用合併後的fetches 和feed_dict調用TensorFlow的session.run() (這裏是真正調用tf.Session().run(fetches ,feed_dict))
- 調用hook.after_run()
- 返回用戶需要的session.run()的結果
- 如果發生了AbortedError或者UnavailableError,則在再次執行run()之前恢復或者重新初始化會話
Exit:當調用close()退出時,按順序執行下列操作
- 調用hook.end()
- 關閉隊列線程queuerunners和會話session
- 在monitored_session的上下文中,抑制由於處理完所有輸入拋出的OutOf Range錯誤。
參考:
https://blog.csdn.net/MrR1ght/article/details/81006343