TensorFlow Estimator 官方文檔之----Checkpoints

本文介紹了 Estimators 模型的保存和恢復。

TensorFlow提供了兩種模型格式:

  • checkpoints:這種格式依賴於創建模型的代碼。
  • SavedModel:這種格式與創建模型的代碼無關。

本文檔主要介紹checkpoints。要詳細瞭解 SavedModel,請參閱《TensorFlow 編程人員指南》的 Saving and Restoring 一章。

1. 保存經過部分訓練的模型

Estimators 在訓練過程中會自動將以下內容保存到磁盤:

  • chenkpoints:訓練過程中的模型快照。
  • event files:其中包含 TensorBoard 用於創建可視化圖表的信息。

通過 model_dir 參數,我們可以指定 Estimator 保存上述文件時的頂級目錄。

# 實例化 estimator
classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir='models/iris')

# 訓練 estimator
classifier.train(
    input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
    steps=200)

如下圖所示,第一次調用 train 方法會將 checkpoints 和 event files 文件添加到 model_dir 目錄中。
在這裏插入圖片描述
在類Unix系統中,可以使用 ls 命令來查看 model_dir 目錄中的內容。

$ ls -l models/iris
checkpoint
events.out.tfevents.timestamp.hostname
graph.pbtxt
model.ckpt-1.data-00000-of-00001
model.ckpt-1.index
model.ckpt-1.meta
model.ckpt-200.data-00000-of-00001
model.ckpt-200.index
model.ckpt-200.meta

通過 ls 命令,我們可以看到,Estimator在step 1(訓練開始)和step 200(訓練結束)創建了 checkpoints 文件。

1.1 model_dir 的默認目錄

如果你沒有指定 model_dir 參數,Estimator 會將 checkpoints 文件保存到一個臨時文件夾。Python的 tempfile.mkdtemp 函數會根據您的操作系統選擇安全的臨時目錄。

classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3)

print(classifier.model_dir) # 查看臨時目錄

1.2 創建 Checkpoints 的頻率

默認情況下,Estimator 會根據以下策略來寫入 checkpoints。

  • 每10分鐘(600秒)向磁盤寫入一個 checkpoint。
  • train 方法開始(第一次迭代)和結束(最後一次迭代)時寫入一個 checkpoint。
  • model_dir 目錄中保留 5 個最近寫入的檢查點。

當然,你可以按如下方式修改 checkpoint 的寫入策略:

  1. 創建一個tf.estimator.RunConfig對象來定義 checkpoint 寫入策略。
  2. 在實例化 Estimator 時,將 RunConfig 對象傳給 Estimator 的 config 參數。
下面的代碼將 checkpoint 寫入間隔設置爲20分鐘,並且保留最近的10個 checkpoints:
est_config = tf.estimator.RunConfig(
    save_checkpoints_secs = 20*60,  # 每20分鐘保存一次 checkpoints
    keep_checkpoint_max = 10,       # 保留最新的10個checkpoints
)

classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir='models/iris',
    config=est_config)

2. 恢復模型

第一次調用 Estimator 的 train 方法時,TensorFlow會保存 checkpoint 文件到 model_dir 目錄。隨後調用 tarinevaluatepredict 方法將進行如下操作:

  1. Estimator 通過運行 model_fn 來構建模型的計算圖。
  2. Estimator 從 checkpoints 中初始化模型參數。

在這裏插入圖片描述

2.1 避免不當恢復

僅在模型和checkpoint兼容的情況下,才能從 checkpoint 恢復模型的狀態。例如,假設您訓練了DNNClassifier包含兩個隱藏層的 Estimator,每個隱藏層有10個節點:

classifier = tf.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir='models/iris')

classifier.train(
    input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
        steps=200)

訓練之後,如果您將每個隱藏層中的神經元數量從10更改爲20,並嘗試重新訓練模型:

classifier2 = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[20, 20],  # Change the number of neurons in the model.
    n_classes=3,
    model_dir='models/iris')

classifier2.train(
    input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
        steps=200)

由於 checkpoint 中的狀態與描述的模型不兼容,因此重新訓練失敗並出現以下錯誤:

...
InvalidArgumentError (see above for traceback): tensor_name =
dnn/hiddenlayer_1/bias/t_0/Adagrad; shape in shape_and_slice spec [10]
does not match the shape stored in checkpoint: [20]

如果要訓練並比較略微不同的模型版本,請爲每一個模型版本(code and model_dir)創建一個單獨的文件夾。

總結

Estimator 對於模型的保存和恢復有着良好的支持。

如果您想了解以下內容,請查看Saving and Restoring

  • 使用 TensorFlow 的低階 API 來保存恢復模型。
  • 使用 SavedModel 格式(與創建模型的代碼無關)來保存、恢復模型。

本文的代碼來源

本文的絕大多數代碼都來自於 premade_estimator.py,部分進行了小的修改。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章