保存模型
# 創建Saver()節點
saver = tf.train.Saver()
# 訓練過程中保存節點
save_path = saver.save(sess, "./ckpt/my_model.ckpt", global_step=epoch)
# 保存最終節點
save_path = saver.save(sess, "./ckpt/my_model_final.ckpt")
讀取模型
# 創建Saver()節點
saver = tf.train.Saver()
# 讀取節點
ckpt = tf.train.get_checkpoint_state('./ckpt/')
# 讀取模型
saver.restore(sess, ckpt.model_checkpoint_path)
使用模型
"""
前期需要將整個計算圖構建出來
但不需要像訓練時init參數
"""
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, "./ckpt/my_model_final.ckpt")
# 測試
print(accuracy.eval({x: my_mnist.test.images, y_: my_mnist.test.labels}))
TODO:restore的時候,參數是如何對應到網絡結構上的