本文介紹了 Estimators 模型的保存和恢復。
TensorFlow提供了兩種模型格式:
- checkpoints:這種格式依賴於創建模型的代碼。
- SavedModel:這種格式與創建模型的代碼無關。
本文檔主要介紹checkpoints。要詳細瞭解 SavedModel
,請參閱《TensorFlow 編程人員指南》的 Saving and Restoring 一章。
1. 保存經過部分訓練的模型
Estimators 在訓練過程中會自動將以下內容保存到磁盤:
- chenkpoints:訓練過程中的模型快照。
- event files:其中包含 TensorBoard 用於創建可視化圖表的信息。
通過 model_dir
參數,我們可以指定 Estimator 保存上述文件時的頂級目錄。
# 實例化 estimator
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[10, 10],
n_classes=3,
model_dir='models/iris')
# 訓練 estimator
classifier.train(
input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
steps=200)
如下圖所示,第一次調用 train
方法會將 checkpoints 和 event files 文件添加到 model_dir
目錄中。
在類Unix系統中,可以使用 ls
命令來查看 model_dir
目錄中的內容。
$ ls -l models/iris
checkpoint
events.out.tfevents.timestamp.hostname
graph.pbtxt
model.ckpt-1.data-00000-of-00001
model.ckpt-1.index
model.ckpt-1.meta
model.ckpt-200.data-00000-of-00001
model.ckpt-200.index
model.ckpt-200.meta
通過 ls
命令,我們可以看到,Estimator在step 1(訓練開始)和step 200(訓練結束)創建了 checkpoints 文件。
1.1 model_dir 的默認目錄
如果你沒有指定 model_dir
參數,Estimator 會將 checkpoints 文件保存到一個臨時文件夾。Python的 tempfile.mkdtemp 函數會根據您的操作系統選擇安全的臨時目錄。
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[10, 10],
n_classes=3)
print(classifier.model_dir) # 查看臨時目錄
1.2 創建 Checkpoints 的頻率
默認情況下,Estimator 會根據以下策略來寫入 checkpoints。
- 每10分鐘(600秒)向磁盤寫入一個 checkpoint。
- 在
train
方法開始(第一次迭代)和結束(最後一次迭代)時寫入一個 checkpoint。 model_dir
目錄中保留 5 個最近寫入的檢查點。
當然,你可以按如下方式修改 checkpoint 的寫入策略:
- 創建一個
tf.estimator.RunConfig
對象來定義 checkpoint 寫入策略。 - 在實例化 Estimator 時,將 RunConfig 對象傳給 Estimator 的 config 參數。
est_config = tf.estimator.RunConfig(
save_checkpoints_secs = 20*60, # 每20分鐘保存一次 checkpoints
keep_checkpoint_max = 10, # 保留最新的10個checkpoints
)
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[10, 10],
n_classes=3,
model_dir='models/iris',
config=est_config)
2. 恢復模型
第一次調用 Estimator 的 train
方法時,TensorFlow會保存 checkpoint 文件到 model_dir 目錄。隨後調用 tarin
、evaluate
、predict
方法將進行如下操作:
- Estimator 通過運行
model_fn
來構建模型的計算圖。 - Estimator 從 checkpoints 中初始化模型參數。
2.1 避免不當恢復
僅在模型和checkpoint兼容的情況下,才能從 checkpoint 恢復模型的狀態。例如,假設您訓練了DNNClassifier包含兩個隱藏層的 Estimator,每個隱藏層有10個節點:
classifier = tf.estimator.DNNClassifier(
feature_columns=feature_columns,
hidden_units=[10, 10],
n_classes=3,
model_dir='models/iris')
classifier.train(
input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
steps=200)
訓練之後,如果您將每個隱藏層中的神經元數量從10更改爲20,並嘗試重新訓練模型:
classifier2 = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[20, 20], # Change the number of neurons in the model.
n_classes=3,
model_dir='models/iris')
classifier2.train(
input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
steps=200)
由於 checkpoint 中的狀態與描述的模型不兼容,因此重新訓練失敗並出現以下錯誤:
...
InvalidArgumentError (see above for traceback): tensor_name =
dnn/hiddenlayer_1/bias/t_0/Adagrad; shape in shape_and_slice spec [10]
does not match the shape stored in checkpoint: [20]
如果要訓練並比較略微不同的模型版本,請爲每一個模型版本(code and model_dir)創建一個單獨的文件夾。
總結
Estimator 對於模型的保存和恢復有着良好的支持。
如果您想了解以下內容,請查看Saving and Restoring:
- 使用 TensorFlow 的低階 API 來保存恢復模型。
- 使用 SavedModel 格式(與創建模型的代碼無關)來保存、恢復模型。
本文的代碼來源
本文的絕大多數代碼都來自於 premade_estimator.py,部分進行了小的修改。