筆記 - 模型訓練:保存讀取使用模型

保存模型

# 創建Saver()節點
saver = tf.train.Saver()

# 訓練過程中保存節點
save_path = saver.save(sess, "./ckpt/my_model.ckpt", global_step=epoch)

# 保存最終節點
save_path = saver.save(sess, "./ckpt/my_model_final.ckpt")

讀取模型

# 創建Saver()節點
saver = tf.train.Saver()

# 讀取節點
ckpt = tf.train.get_checkpoint_state('./ckpt/')

# 讀取模型
saver.restore(sess, ckpt.model_checkpoint_path)

使用模型

"""
前期需要將整個計算圖構建出來
但不需要像訓練時init參數
"""

saver = tf.train.Saver()

with tf.Session() as sess:
    saver.restore(sess, "./ckpt/my_model_final.ckpt")

    # 測試
    print(accuracy.eval({x: my_mnist.test.images, y_: my_mnist.test.labels}))

TODO:restore的時候,參數是如何對應到網絡結構上的

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章