【Tensorflow】用於構建大規模分佈式模型的高階API(四)_Estimator模型的保存

import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn import datasets
from sklearn.model_selection import train_test_split
iris = datasets.load_iris()
train_x,test_x,train_y,test_y= train_test_split(iris.data,iris.target,test_size=0.3,random_state=0)

def train_input_fn(features, labels, batch_size):
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
    return dataset.shuffle(1000).repeat().batch(batch_size) # 由於repeat()未指定參數,因此會一直循環下去

def eval_input_fn(features, labels, batch_size):
    features=dict(features)
    if labels is None:
        inputs = features
    else:
        inputs = (features, labels)
    dataset = tf.data.Dataset.from_tensor_slices(inputs)
    assert batch_size is not None, "batch_size must not be None"
    dataset = dataset.batch(batch_size) # 沒有repeat(),如果只會遍歷數據集一次
    return dataset

feature_names = ['SepalLength','SepalWidth','PetalLength','PetalWidth']
my_feature_columns = []
for key in feature_names:
    my_feature_columns.append(tf.feature_column.numeric_column(key=key))

一、保存訓練過程中的模型(checkpoints)

Estimator會自動將以下內容寫入磁盤:

檢查點(checkpoint):訓練過程中的模型版本;

事件文件(event):用於TensorBorad繪製可視化圖的信息;

如果是使Estimator在指定的目錄中保存上面的信息,需要在Estimator的構造函數中指定參數model_dir的值,在train時就會自動保存。例如

classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir='models/iris')
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': 'models/iris', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f681ca55278>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

二、默認checkpoint目錄

如果沒有在構造Estimator時指定model_dir,那麼Estimator會將checkpoint文件寫入由Python的tempfile.mkdtemp函數選擇的臨時目錄中。

三、控制checkpoint的行爲

1.默認checkpoint的保存行爲

每10分鐘(600 秒)寫入一個checkpoint;

在train方法開始(第一次迭代)和完成(最後一次迭代)時寫入一個checkpoint;

只在目錄中保留5個最近寫入的checkpoint;

2.修改默認checkpoint的保存行爲

創建一個RunConfig對象來定義所需的時間安排;

在實例化Estimator時,將該RunConfig對象傳遞給Estimator的config參數;

my_checkpointing_config = tf.estimator.RunConfig(
    save_checkpoints_secs = 20*60,  # 每20分鐘保存一次checkpoint
    keep_checkpoint_max = 10,       # 保存10個最近的checkpoints
)

classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir='models/iris',
    config=my_checkpointing_config)
INFO:tensorflow:Using config: {'_model_dir': 'models/iris', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 1200, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 10, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f681ca55b38>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

四、從checkpoint中恢復模型

第一次調用Estimator的train方法時,TensorFlow會將一個檢查點保存到model_dir中。隨後每次調用Estimator的train、evaluate 或 predict 方法時,都會發生下列情況:

1.Estimator 通過運行 model_fn() 構建模型圖。

2.Estimator 根據最近寫入的檢查點中存儲的數據來初始化新模型的權重。

一旦存在檢查點,TensorFlow 就會在您每次調用 train()、evaluate() 或 predict() 時重建模型。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章