1、概述
和老版本的tensorflow一樣,模型需要進行保存,而且這種保存方式是週期性的。因爲在很多情況下,梯度會在局部最小值左右進行搖擺,也就是說,在很多情況下,最後一次訓練的模型不見得是最優化的。
2、保存模型
我們可以在構建模型時,制定檢查點保存的位置,首先我們可以用下面命令創建一個文件夾。
可以在構建模型時加入參數
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[10, 10],
n_classes=3,
# 指定模型(檢查點存放位置)
model_dir='./checkpoint')
在訓練模型之後可以在相應的文件夾看到下面的文件:
默認情況下,只保存了第一次和最後一次的模型,另外除了在改文件夾中保存模型外,事件文件也被記錄在該文件夾中。這與老版本的tensorflow有些不同。
- 檢查點:訓練期間所創建的模型版本。
- 事件文件:其中包含 TensorBoard 用於創建可視化圖表的信息。
3、配置模型保存參數
默認情況下,Estimator 按照以下時間安排將檢查點保存到 model_dir
中:
- 每 10 分鐘(600 秒)寫入一個檢查點。
- 在
train
方法開始(第一次迭代)和完成(最後一次迭代)時寫入一個檢查點。 - 只在目錄中保留 5 個最近寫入的檢查點。
可以自定義配置文件,來對檢查點的保存方式進行修改
model_save_config = tf.estimator.RunConfig(save_checkpoints_steps=500, # 制定每多少步保存一次模型
keep_checkpoint_max=6) # 指定最多保存多少個模型
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[10, 10],
n_classes=3,
# 指定模型(檢查點存放位置)
model_dir='./checkpoint',
config=model_save_config)
修改之後模型保存情況如下圖所示:
4、恢復模型
第一次調用 Estimator 的 train
方法時,TensorFlow 會將一個檢查點保存到 model_dir
中。隨後每次調用 Estimator 的 train
、eval
或 predict
方法時,都會發生下列情況:
- Estimator 通過運行
model_fn()
構建模型圖。(要詳細瞭解model_fn()
,請參閱創建自定義 Estimator。) - Estimator 根據最近寫入的檢查點中存儲的數據來初始化新模型的權重。
換言之,如下圖所示,一旦存在檢查點,TensorFlow 就會在您每次調用 train()
、evaluate()
或 predict()
時重建模型。
注意: 如果在訓練之後,更改模型結構再重新進行訓練時,則原檢查點的內容不兼容系統報錯