第二章:新版tensorflow入門,使用檢查點保存模型

1、概述

和老版本的tensorflow一樣,模型需要進行保存,而且這種保存方式是週期性的。因爲在很多情況下,梯度會在局部最小值左右進行搖擺,也就是說,在很多情況下,最後一次訓練的模型不見得是最優化的。

2、保存模型

我們可以在構建模型時,制定檢查點保存的位置,首先我們可以用下面命令創建一個文件夾。


可以在構建模型時加入參數

classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
#   指定模型(檢查點存放位置)  
    model_dir='./checkpoint')

在訓練模型之後可以在相應的文件夾看到下面的文件:


默認情況下,只保存了第一次和最後一次的模型,另外除了在改文件夾中保存模型外,事件文件也被記錄在該文件夾中。這與老版本的tensorflow有些不同。

  • 檢查點:訓練期間所創建的模型版本。
  • 事件文件:其中包含 TensorBoard 用於創建可視化圖表的信息。

3、配置模型保存參數

默認情況下,Estimator 按照以下時間安排將檢查點保存到 model_dir 中:

  • 每 10 分鐘(600 秒)寫入一個檢查點。
  • 在 train 方法開始(第一次迭代)和完成(最後一次迭代)時寫入一個檢查點。
  • 只在目錄中保留 5 個最近寫入的檢查點。

可以自定義配置文件,來對檢查點的保存方式進行修改

model_save_config = tf.estimator.RunConfig(save_checkpoints_steps=500, # 制定每多少步保存一次模型
                                           keep_checkpoint_max=6) # 指定最多保存多少個模型
classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
#   指定模型(檢查點存放位置)  
    model_dir='./checkpoint',
    config=model_save_config)

修改之後模型保存情況如下圖所示:


4、恢復模型

第一次調用 Estimator 的 train 方法時,TensorFlow 會將一個檢查點保存到 model_dir 中。隨後每次調用 Estimator 的 traineval 或 predict 方法時,都會發生下列情況:

  1. Estimator 通過運行 model_fn() 構建模型。(要詳細瞭解 model_fn(),請參閱創建自定義 Estimator。)
  2. Estimator 根據最近寫入的檢查點中存儲的數據來初始化新模型的權重。

換言之,如下圖所示,一旦存在檢查點,TensorFlow 就會在您每次調用 train()evaluate() 或 predict() 時重建模型。



注意: 如果在訓練之後,更改模型結構再重新進行訓練時,則原檢查點的內容不兼容系統報錯



發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章