Tensorflow模型持久化

Tensorflow模型持久化

1. 保存兩個變量和的模型

import tensorflow as tf
v1 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
v2 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
result = v1 + v2

init_op = tf.global_variables_initializer()
#聲明tf.train.Saver類用於保存模型
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    saver.save(sess, "Saved_model/model.ckpt")  

模型保存(Tensorflow會將計算圖的結構和圖上參數取值分別保存):

  • model.ckpt.meta: 保存Tensorflow計算圖的結構。
  • model.ckpt: 保存Tensorflow程序中每一個變量的取值。
  • checkpoint: 保存目錄下所有的模型文件列表。

2. 加載保存的模型

v1 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
v2 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
result = v1 + v2

init_op = tf.global_variables_initializer()
#聲明tf.train.Saver類用於保存模型
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model.ckpt")
    print sess.run(result)  

注:加載模型的代碼和保存模型的代碼基本一樣。
    兩端代碼中,唯一不同的是,在加載模型的代碼中沒有運行變量的初始話過程,而是將變量的值通過已經保存的模型加載進來。

3. 直接加載持久化的圖

saver = tf.train.import_meta_graph("Saved_model/model.ckpt.meta")
v3 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))

with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model.ckpt")
    print sess.run(v1) 
    print sess.run(v2) 
    print sess.run(v3) 
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章