tensorflow模型保存與複用多種方式

 引用一段:checkpoint是一個內部事件,該事件激活後會觸發數據庫寫進程將數據緩衝中的髒數據寫到數據文件中。

checkpoint主要2個作用:

 保證數據庫的一致性

縮短實例恢復時間

通俗的講,checkpoint像word的自動保存一樣。

tensorflow模型包含  meta圖(網絡結構圖) 和 checkpoint文件(網絡結構裏的參數值,現已經被分拆爲3個文件)

即總的文件包含目錄爲:

model.data-00000-of-00001  保存變量值

model.index                           索引文件,記錄每個變量名在 .data 分片文件中的存儲位置

model.meta                           結構圖

checkpoint                            文本文件,記錄中間節點上保存的模型的名稱

import tensorflow as tf
import os


# Trainable model parameters, shared by inference()/loss()/train() below.
W = tf.Variable(tf.zeros([2, 1]), name="weights")  # (2, 1) column: one weight per input feature (weight, age)
b = tf.Variable(0., name="bias")  # scalar bias


def inference(X):
    """Return the linear-model prediction X @ W + b."""
    linear_part = tf.matmul(X, W)
    return linear_part + b


def loss(X, Y):
    """Sum of squared errors between the model's predictions and Y."""
    predictions = inference(X)
    squared_errors = tf.squared_difference(Y, predictions)
    return tf.reduce_sum(squared_errors)


def inputs():
    """Return the training set as float tensors: (weight, age) features and blood-fat targets."""
    # Each sample is (weight, age, blood_fat_content).
    samples = [
        (84, 46, 354), (73, 20, 190), (65, 52, 405), (70, 30, 263), (76, 57, 451),
        (69, 25, 302), (63, 28, 288), (72, 36, 385), (79, 57, 402), (75, 44, 365),
        (27, 24, 209), (89, 31, 290), (65, 52, 346), (57, 23, 254), (59, 60, 395),
        (69, 48, 434), (60, 34, 220), (79, 51, 374), (75, 50, 308), (82, 34, 220),
        (59, 46, 311), (67, 23, 181), (85, 37, 274), (55, 40, 303), (63, 30, 244),
    ]
    weight_age = [[weight, age] for weight, age, _ in samples]
    blood_fat_content = [fat for _, _, fat in samples]
    return tf.to_float(weight_age), tf.to_float(blood_fat_content)


def train(total_loss):
    """Build and return one gradient-descent step minimizing total_loss."""
    learning_rate = 0.0000001
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    return optimizer.minimize(total_loss)


def evaluate(sess, X, Y):
    """Print predictions for two held-out samples (expected roughly 303 and 256)."""
    for sample in ([[80., 25.]], [[65., 25.]]):
        print(sess.run(inference(sample)))

【1】模型訓練:

with tf.Session() as sess:
    X, Y = inputs()
    total_loss = loss(X, Y)
    train_op = train(total_loss)

    # Queue-runner plumbing for the TF1-style input pipeline.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    init = tf.global_variables_initializer()
    sess.run(init)
    training_steps = 10000
    # Saver must be created after all variables exist in the graph.
    saver = tf.train.Saver()
    for step in range(training_steps):
        sess.run(train_op)
        if step % 10 == 0:
            print("loss", sess.run(total_loss))

        # Periodic checkpoint: writes my-model-<step>.{data,index,meta}
        # plus the "checkpoint" bookkeeping file into model_save_dir.
        if step % 1000 == 0:
            saver.save(sess, r"E:\tf_project\練習\model_save_dir\my-model", global_step=step)

    evaluate(sess, X, Y)
    # Final checkpoint tagged with the last step number.
    saver.save(sess, r"E:\tf_project\練習\model_save_dir\my-model", global_step=training_steps)

    coord.request_stop()
    coord.join(threads)
    sess.close()

【2】模型重新加載

1、加載時間最近的模型,使用ckpt.model_checkpoint_path

with tf.Session() as sess:
    X, Y = inputs()
    total_loss = loss(X, Y)
    train_op = train(total_loss)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    initial_step = 0
    training_steps = 30000
    saver = tf.train.Saver()
    # get_checkpoint_state() reads the "checkpoint" bookkeeping file;
    # os.path.dirname strips the "my-model" prefix to get the directory.
    ckpt = tf.train.get_checkpoint_state(os.path.dirname(r"E:\tf_project\練習\model_save_dir\my-model"))

    # model_checkpoint_path is the most recent checkpoint prefix.
    # NOTE(review): no initializer is run here — restore must cover every
    # variable, otherwise some stay uninitialized.
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print(ckpt.model_checkpoint_path)
        # Recover the global step from the trailing "-<step>" in the filename.
        initial_step = int(ckpt.model_checkpoint_path.rsplit('-', 1)[1])

    # Resume training from the restored step rather than from 0.
    for step in range(initial_step, training_steps):
        sess.run(train_op)
        if step % 10 == 0:
            print("loss", sess.run(total_loss))

        if step % 1000 == 0:
            saver.save(sess, r"E:\tf_project\練習\model_save_dir1\my-model", global_step=step)

    evaluate(sess, X, Y)
    saver.save(sess, r"E:\tf_project\練習\model_save_dir1\my-model", global_step=training_steps)

    coord.request_stop()
    coord.join(threads)
    sess.close()



output:

E:\tf_project\練習\model_save_dir\my-model-9000-20000
loss 5214449.5
loss 5214338.0
loss 5214226.0
loss 5214114.0
.
.
.
loss 5106910.0
loss 5106805.5
loss 5106701.0
loss 5106597.0
[[319.9712]]
[[270.7156]]

2、從時間最近的幾個模型中選取一個或者多個模型加載

使用 ckpt.all_model_checkpoint_paths


with tf.Session() as sess:
    X, Y = inputs()
    total_loss = loss(X, Y)
    train_op = train(total_loss)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # Initialize first so variables are defined even when no checkpoint
    # is found; a successful restore below simply overwrites them.
    init = tf.global_variables_initializer()
    sess.run(init)
    initial_step = 0
    training_steps = 30000
    saver = tf.train.Saver()

    ckpt = tf.train.get_checkpoint_state(os.path.dirname(r"E:\tf_project\練習\model_save_dir\my-model"))
    # BUG FIX: guard before indexing — the original indexed
    # ckpt.all_model_checkpoint_paths[1] before checking ckpt, which
    # crashes when no checkpoint state exists or fewer than 2 are kept.
    if ckpt and len(ckpt.all_model_checkpoint_paths) > 1:
        # all_model_checkpoint_paths lists the retained checkpoints,
        # oldest first; pick the second-oldest one here.
        path = ckpt.all_model_checkpoint_paths[1]
        print(ckpt.all_model_checkpoint_paths)
        saver.restore(sess, path)
        # BUG FIX: derive the resume step from the checkpoint we actually
        # restored (path), not from model_checkpoint_path (the newest one).
        initial_step = int(path.rsplit('-', 1)[1])

    # BUG FIX: the original computed initial_step but still looped from 0,
    # retraining already-completed steps; resume from initial_step instead.
    for step in range(initial_step, training_steps):
        sess.run(train_op)
        if step % 10 == 0:
            print("loss", sess.run(total_loss))

        if step % 1000 == 0:
            saver.save(sess, r"E:\tf_project\練習\model_save_dir1\my-model", global_step=step)

    evaluate(sess, X, Y)
    saver.save(sess, r"E:\tf_project\練習\model_save_dir1\my-model", global_step=training_steps)

    coord.request_stop()
    coord.join(threads)
    sess.close()

output:

['E:\\tf_project\\練習\\model_save_dir\\my-model-9000-16000',
 'E:\\tf_project\\練習\\model_save_dir\\my-model-9000-17000',
 'E:\\tf_project\\練習\\model_save_dir\\my-model-9000-18000', 
 'E:\\tf_project\\練習\\model_save_dir\\my-model-9000-19000',
 'E:\\tf_project\\練習\\model_save_dir\\my-model-9000-20000']
loss 5217809.0
loss 5217696.5
loss 5217585.0
loss 5217473.0
loss 5217361.0
.
.
.
loss 5110039.5
loss 5109935.5
loss 5109830.5
loss 5109727.0
[[319.98105]]
[[270.67313]]

3、使用結構圖加載 tf.train.import_meta_graph

with tf.Session() as sess:
    X, Y = inputs()
    total_loss = loss(X, Y)
    train_op = train(total_loss)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    init = tf.global_variables_initializer()
    sess.run(init)
    training_steps = 10000

    # import_meta_graph() adds the saved graph's nodes to the CURRENT
    # default graph and returns a Saver built from the imported graph.
    # NOTE(review): the same model was already rebuilt above, so this
    # presumably duplicates the ops; restore() then fills the imported
    # copies rather than the variables train_op uses — the higher initial
    # loss in the recorded output is consistent with that. Confirm.
    saver = tf.train.import_meta_graph(r"E:\tf_project\練習\model_save_dir\my-model-20000.meta")
    # latest_checkpoint() returns the newest checkpoint prefix in the dir.
    saver.restore(sess, tf.train.latest_checkpoint(r"E:\tf_project\練習\model_save_dir"))
    print(tf.train.latest_checkpoint(r"E:\tf_project\練習\model_save_dir"))

    for step in range(training_steps):
        sess.run(train_op)
        if step % 10 == 0:
            print("loss", sess.run(total_loss))

        if step % 1000 == 0:
            saver.save(sess, r"E:\tf_project\練習\model_save_dir1\my-model", global_step=step)

    evaluate(sess, X, Y)
    saver.save(sess, r"E:\tf_project\練習\model_save_dir1\my-model", global_step=training_steps)

    coord.request_stop()
    coord.join(threads)
    sess.close()

output:

E:\tf_project\練習\model_save_dir\my-model-9000-20000
loss 7608772.0
loss 5352849.5
loss 5350043.5
loss 5347919.0
loss 5346300.5
.
.
.
loss 5226120.5
loss 5226008.0
loss 5225895.5
loss 5225782.0
[[320.33838]]
[[269.12772]]

4、通用加載方式,使用 saver.restore

這裏可以指定從哪一個模型進行加載

with tf.Session() as sess:
    # Restore directly from an explicit checkpoint prefix — no
    # "checkpoint" bookkeeping file or get_checkpoint_state() needed.
    CHECKPOINT_PATH = r"E:\tf_project\練習\model_save_dir\my-model-9000"
    # Saver is built from the variables that already exist at this point
    # (the module-level W and b); restore replaces initialization.
    saver = tf.train.Saver()
    saver.restore(sess, CHECKPOINT_PATH)
    print(CHECKPOINT_PATH)
    X, Y = inputs()
    total_loss = loss(X, Y)
    train_op = train(total_loss)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    initial_step = 0
    training_steps = 10000

    for step in range(initial_step, training_steps):
        sess.run(train_op)
        if step % 10 == 0:
            print("loss", sess.run(total_loss))

        # New checkpoints are written back under the same prefix.
        if step % 1000 == 0:
            saver.save(sess, CHECKPOINT_PATH, global_step=step)

    evaluate(sess, X, Y)
    saver.save(sess, CHECKPOINT_PATH, global_step=training_steps)

    coord.request_stop()
    coord.join(threads)
    sess.close()

output:

E:\tf_project\練習\model_save_dir\my-model-9000
loss 5236970.0
loss 5236958.5
loss 5236947.5
loss 5236935.5
.
.
.
loss 5225708.5
loss 5225696.5
loss 5225685.5
[[320.33884]]
[[269.1281]]

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章