引用一段:checkpoint是一個內部事件,該事件激活後會觸發數據庫寫進程將數據緩衝中的髒數據寫到數據文件中。
checkpoint主要有2個作用:
保證數據庫的一致性
縮短實例恢復時間
通俗的講,checkpoint像word的自動保存一樣。
tensorflow模型包含 meta圖(網絡結構圖) 和 checkpoint文件(網絡結構裏的參數值,現已經被分拆爲3個文件)
即總的文件包含目錄爲:
model.data-00000-of-00001 保存變量值
model.index 保存 .data 和 .meta 文件對應關係
model.meta 結構圖
checkpoint 文本文件,記錄中間節點上保存的模型的名稱
import tensorflow as tf
import os
# Model parameters shared by inference()/loss() below.
W = tf.Variable(tf.zeros([2, 1]), name="weights")  # weight matrix: 2 features -> 1 output
b = tf.Variable(0., name="bias")  # scalar bias term
def inference(X):
    """Linear model: map a batch of (weight, age) rows to predicted values."""
    prediction = tf.matmul(X, W) + b
    return prediction
def loss(X, Y):
    """Sum of squared errors between the targets Y and the model's predictions."""
    predictions = inference(X)
    squared_errors = tf.squared_difference(Y, predictions)
    return tf.reduce_sum(squared_errors)
def inputs():
    """Return the fixed training set as float tensors.

    Features are (weight, age) pairs; labels are blood-fat-content values.
    """
    weight_age = [
        [84, 46], [73, 20], [65, 52], [70, 30], [76, 57],
        [69, 25], [63, 28], [72, 36], [79, 57], [75, 44],
        [27, 24], [89, 31], [65, 52], [57, 23], [59, 60],
        [69, 48], [60, 34], [79, 51], [75, 50], [82, 34],
        [59, 46], [67, 23], [85, 37], [55, 40], [63, 30],
    ]
    blood_fat_content = [
        354, 190, 405, 263, 451, 302, 288, 385, 402, 365,
        209, 290, 346, 254, 395, 434, 220, 374, 308, 220,
        311, 181, 274, 303, 244,
    ]
    return tf.to_float(weight_age), tf.to_float(blood_fat_content)
def train(total_loss, learning_rate=0.0000001):
    """Return a gradient-descent op that minimizes ``total_loss``.

    Args:
        total_loss: scalar loss tensor to minimize.
        learning_rate: step size for gradient descent. Defaults to the value
            that was previously hard-coded, so existing callers are unaffected.
    """
    return tf.train.GradientDescentOptimizer(learning_rate).minimize(total_loss)
def evaluate(sess, X, Y):
    """Print predictions for two probe samples (reference answers: ~303, ~256)."""
    for sample in ([[80., 25.]], [[65., 25.]]):
        print(sess.run(inference(sample)))
【1】模型訓練:
# [1] Train the model and checkpoint it every 1000 steps.
with tf.Session() as sess:
    X, Y = inputs()
    total_loss = loss(X, Y)
    train_op = train(total_loss)

    # Start input-queue threads before entering the training loop.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Initialize all variables (removed the dead commented-out duplicate).
    sess.run(tf.global_variables_initializer())

    training_steps = 10000
    saver = tf.train.Saver()
    for step in range(training_steps):
        sess.run(train_op)
        if step % 10 == 0:
            print("loss", sess.run(total_loss))
        if step % 1000 == 0:
            # Periodic checkpoint; global_step is appended to the filename.
            saver.save(sess, r"E:\tf_project\練習\model_save_dir\my-model", global_step=step)

    evaluate(sess, X, Y)
    # Final checkpoint after the last training step.
    saver.save(sess, r"E:\tf_project\練習\model_save_dir\my-model", global_step=training_steps)

    coord.request_stop()
    coord.join(threads)
    # The `with` block closes the session; the redundant sess.close() was removed.
【2】模型重新加載
1、加載時間最近的模型,使用ckpt.model_checkpoint_path
# Resume training from the most recent checkpoint (ckpt.model_checkpoint_path).
with tf.Session() as sess:
    X, Y = inputs()
    total_loss = loss(X, Y)
    train_op = train(total_loss)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    initial_step = 0
    training_steps = 30000
    saver = tf.train.Saver()

    ckpt = tf.train.get_checkpoint_state(os.path.dirname(r"E:\tf_project\練習\model_save_dir\my-model"))
    if ckpt and ckpt.model_checkpoint_path:
        # Restore variables and resume from the step encoded in the filename
        # (my-model-<step> -> <step>).
        saver.restore(sess, ckpt.model_checkpoint_path)
        print(ckpt.model_checkpoint_path)
        initial_step = int(ckpt.model_checkpoint_path.rsplit('-', 1)[1])
    else:
        # BUG FIX: without this fallback the variables were never initialized
        # when no checkpoint exists, and the first sess.run(train_op) would
        # fail with an uninitialized-variable error.
        sess.run(tf.global_variables_initializer())

    for step in range(initial_step, training_steps):
        sess.run(train_op)
        if step % 10 == 0:
            print("loss", sess.run(total_loss))
        if step % 1000 == 0:
            saver.save(sess, r"E:\tf_project\練習\model_save_dir1\my-model", global_step=step)

    evaluate(sess, X, Y)
    saver.save(sess, r"E:\tf_project\練習\model_save_dir1\my-model", global_step=training_steps)

    coord.request_stop()
    coord.join(threads)
    # The `with` block closes the session; explicit sess.close() removed.
output:
E:\tf_project\練習\model_save_dir\my-model-9000-20000
loss 5214449.5
loss 5214338.0
loss 5214226.0
loss 5214114.0
.
.
.
loss 5106910.0
loss 5106805.5
loss 5106701.0
loss 5106597.0
[[319.9712]]
[[270.7156]]
2、從時間最近的幾個模型中選取一個或者多個模型加載
使用 ckpt.all_model_checkpoint_paths
# Restore one specific checkpoint chosen from the retained set
# (ckpt.all_model_checkpoint_paths) instead of only the latest one.
with tf.Session() as sess:
    X, Y = inputs()
    total_loss = loss(X, Y)
    train_op = train(total_loss)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Initialize first; a successful restore below overwrites these values.
    sess.run(tf.global_variables_initializer())

    initial_step = 0
    training_steps = 30000
    saver = tf.train.Saver()

    ckpt = tf.train.get_checkpoint_state(os.path.dirname(r"E:\tf_project\練習\model_save_dir\my-model"))
    # BUG FIX: the original indexed ckpt.all_model_checkpoint_paths BEFORE the
    # `if ckpt` guard, which raises AttributeError when no checkpoint state
    # exists. The access now happens only inside the guard.
    if ckpt and ckpt.all_model_checkpoint_paths:
        print(ckpt.all_model_checkpoint_paths)
        # NOTE(review): index 1 assumes at least two checkpoints are retained
        # (Saver keeps up to 5 by default) — confirm for your directory.
        path = ckpt.all_model_checkpoint_paths[1]
        saver.restore(sess, path)
        # BUG FIX: derive the resume step from the checkpoint actually
        # restored; the original parsed model_checkpoint_path (a different,
        # newer file) and then never used the result.
        initial_step = int(path.rsplit('-', 1)[1])

    for step in range(initial_step, training_steps):
        sess.run(train_op)
        if step % 10 == 0:
            print("loss", sess.run(total_loss))
        if step % 1000 == 0:
            saver.save(sess, r"E:\tf_project\練習\model_save_dir1\my-model", global_step=step)

    evaluate(sess, X, Y)
    saver.save(sess, r"E:\tf_project\練習\model_save_dir1\my-model", global_step=training_steps)

    coord.request_stop()
    coord.join(threads)
    # The `with` block closes the session; explicit sess.close() removed.
output:
['E:\\tf_project\\練習\\model_save_dir\\my-model-9000-16000',
'E:\\tf_project\\練習\\model_save_dir\\my-model-9000-17000',
'E:\\tf_project\\練習\\model_save_dir\\my-model-9000-18000',
'E:\\tf_project\\練習\\model_save_dir\\my-model-9000-19000',
'E:\\tf_project\\練習\\model_save_dir\\my-model-9000-20000']
loss 5217809.0
loss 5217696.5
loss 5217585.0
loss 5217473.0
loss 5217361.0
.
.
.
loss 5110039.5
loss 5109935.5
loss 5109830.5
loss 5109727.0
[[319.98105]]
[[270.67313]]
3、使用結構圖加載 tf.train.import_meta_graph
# Load via the saved graph structure: tf.train.import_meta_graph.
with tf.Session() as sess:
    X, Y = inputs()
    total_loss = loss(X, Y)
    train_op = train(total_loss)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Initialize before restoring, matching the original statement order.
    init = tf.global_variables_initializer()
    sess.run(init)

    training_steps = 10000

    # Rebuild the serialized graph from the .meta file; the returned saver is
    # bound to the variables of that imported graph.
    # NOTE(review): importing the meta graph adds a second copy of the model
    # to the default graph — presumably the restored weights belong to that
    # copy, not to the train_op built above; verify against the output.
    saver = tf.train.import_meta_graph(r"E:\tf_project\練習\model_save_dir\my-model-20000.meta")
    latest = tf.train.latest_checkpoint(r"E:\tf_project\練習\model_save_dir")
    saver.restore(sess, latest)
    print(latest)

    for step in range(training_steps):
        sess.run(train_op)
        if step % 10 == 0:
            print("loss", sess.run(total_loss))
        if step % 1000 == 0:
            saver.save(sess, r"E:\tf_project\練習\model_save_dir1\my-model", global_step=step)

    evaluate(sess, X, Y)
    saver.save(sess, r"E:\tf_project\練習\model_save_dir1\my-model", global_step=training_steps)

    coord.request_stop()
    coord.join(threads)
    sess.close()  # redundant inside `with`, kept to mirror the original flow
output:
E:\tf_project\練習\model_save_dir\my-model-9000-20000
loss 7608772.0
loss 5352849.5
loss 5350043.5
loss 5347919.0
loss 5346300.5
.
.
.
loss 5226120.5
loss 5226008.0
loss 5225895.5
loss 5225782.0
[[320.33838]]
[[269.12772]]
4、通用加載方式,使用 saver.restore
這裏可以指定從哪一個模型進行加載
# Generic loading: saver.restore with an explicitly chosen checkpoint file.
with tf.Session() as sess:
    CHECKPOINT_PATH = r"E:\tf_project\練習\model_save_dir\my-model-9000"

    # Restore W and b straight from the named checkpoint; no initializer is
    # run here because restore assigns every saved variable.
    saver = tf.train.Saver()
    saver.restore(sess, CHECKPOINT_PATH)
    print(CHECKPOINT_PATH)

    X, Y = inputs()
    total_loss = loss(X, Y)
    train_op = train(total_loss)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    initial_step = 0
    training_steps = 10000
    for step in range(initial_step, training_steps):
        sess.run(train_op)
        if step % 10 == 0:
            print("loss", sess.run(total_loss))
        if step % 1000 == 0:
            saver.save(sess, CHECKPOINT_PATH, global_step=step)

    evaluate(sess, X, Y)
    saver.save(sess, CHECKPOINT_PATH, global_step=training_steps)

    coord.request_stop()
    coord.join(threads)
    sess.close()  # redundant inside `with`, kept to mirror the original flow
output:
E:\tf_project\練習\model_save_dir\my-model-9000
loss 5236970.0
loss 5236958.5
loss 5236947.5
loss 5236935.5
.
.
.
loss 5225708.5
loss 5225696.5
loss 5225685.5
[[320.33884]]
[[269.1281]]