爲了讓訓練結果可以複用,需要將訓練得到的神經網絡模型持久化。Tensorflow提供了一個簡單的API來保存和還原一個神經網絡模型。這個API是tf.train.Saver類。
用tf.train.Saver類保存tensorflow計算圖的代碼:
import tensorflow as tf
#聲明兩個變量求和
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2
init_op = tf.global_variables_initializer()
#聲明tf.train.Saver類用於保存模型
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init_op)
#指定模型保存路徑
saver.save(sess, '/path/to/model/model.ckpt')
代碼運行完會出現四個文件,如下圖:
最後一個文件model.ckpt.meta保存了計算圖的結構。
加載已經保存的tensorflow模型代碼:
import tensorflow as tf
#聲明兩個變量求和
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2
saver = tf.train.Saver()
with tf.Session() as sess:
#加載已經保存的模型,並通過已經保存的模型中變量的值來計算加法
saver.restore(sess, "/path/to/model/model.ckpt")
print(sess.run(result))
和上面代碼不同的是沒有變量的初始化過程,而是將變量的值通過已經保存的模型加載進來。代碼中重複定義了計算圖,也可以省略直接加載已經持久化的圖。如下:
import tensorflow as tf
#直接加載持久化的圖
saver = tf.train.import_meta_graph(
"/path/to/model/model.ckpt.meta"
)
with tf.Session() as sess:
saver.restore(sess, "/path/to/model/model.ckpt")
#通過張量的名稱來獲取張量
print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))
保存和加載部分變量
在上面的程序中,默認保存和加載了tensorflow計算圖上的全部變量。但有時需要保存和加載部分變量,這是可以通過列表來指定需要保存或加載的變量。例如在上面的程序中寫saver = tf.train.Saver([v1]),那麼只有變量v1會被加載進來。
保存或者加載時給變量重命名
tensorflow可以通過字典將模型保存時的變量名和需要加載的變量聯繫起來。例子如下:
#這裏聲明的變量和已經保存的模型中的變量不同
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='other-v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='other-v2')
#字典指定了原來名稱爲v1的變量現在加載到變量v1中,other-v1v2 名稱爲,一樣
saver = tf.train.Saver({"v1":v1, "v2":v2})
這樣做可以方便使用變量的滑動平均值。