Tensorflow模型代碼持久化實現

爲了讓訓練結果可以複用,需要將訓練得到的神經網絡模型持久化。Tensorflow提供了一個簡單的API來保存和還原一個神經網絡模型。這個API是tf.train.Saver類。

用tf.train.Saver類保存tensorflow計算圖的代碼:

import tensorflow as tf

#聲明兩個變量求和
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2

init_op = tf.global_variables_initializer()
#聲明tf.train.Saver類用於保存模型
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    #指定模型保存路徑
    saver.save(sess, '/path/to/model/model.ckpt')

代碼運行完會出現四個文件,如下圖:

最後一個文件model.ckpt.meta保存了計算圖的結構。

加載已經保存的tensorflow模型代碼:

import tensorflow as tf

#聲明兩個變量求和
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2

saver = tf.train.Saver()

with tf.Session() as sess:
    #加載已經保存的模型,並通過已經保存的模型中變量的值來計算加法
    saver.restore(sess, "/path/to/model/model.ckpt")
    print(sess.run(result))

和上面代碼不同的是沒有變量的初始化過程,而是將變量的值通過已經保存的模型加載進來。代碼中重複定義了計算圖,也可以省略直接加載已經持久化的圖。如下:

import tensorflow as tf
#直接加載持久化的圖

saver = tf.train.import_meta_graph(
    "/path/to/model/model.ckpt.meta"
)
with tf.Session() as sess:
    saver.restore(sess, "/path/to/model/model.ckpt")
    #通過張量的名稱來獲取張量
    print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))

保存和加載部分變量

    在上面的程序中,默認保存和加載了tensorflow計算圖上的全部變量。但有時需要保存和加載部分變量,這是可以通過列表來指定需要保存或加載的變量。例如在上面的程序中寫saver = tf.train.Saver([v1]),那麼只有變量v1會被加載進來。

保存或者加載時給變量重命名

tensorflow可以通過字典將模型保存時的變量名和需要加載的變量聯繫起來。例子如下:

#這裏聲明的變量和已經保存的模型中的變量不同
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='other-v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='other-v2')
#字典指定了原來名稱爲v1的變量現在加載到變量v1中,other-v1v2 名稱爲,一樣
saver = tf.train.Saver({"v1":v1, "v2":v2})

這樣做可以方便使用變量的滑動平均值。

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章