Tensorflow模型代码持久化实现

为了让训练结果可以复用,需要将训练得到的神经网络模型持久化。Tensorflow提供了一个简单的API来保存和还原一个神经网络模型。这个API是tf.train.Saver类。

用tf.train.Saver类保存tensorflow计算图的代码:

import tensorflow as tf

#声明两个变量求和
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2

init_op = tf.global_variables_initializer()
#声明tf.train.Saver类用于保存模型
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    #指定模型保存路径
    saver.save(sess, '/path/to/model/model.ckpt')

代码运行完会出现四个文件,如下图:

最后一个文件model.ckpt.meta保存了计算图的结构。

加载已经保存的tensorflow模型代码:

import tensorflow as tf

#声明两个变量求和
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2

saver = tf.train.Saver()

with tf.Session() as sess:
    #加载已经保存的模型,并通过已经保存的模型中变量的值来计算加法
    saver.restore(sess, "/path/to/model/model.ckpt")
    print(sess.run(result))

和上面代码不同的是没有变量的初始化过程,而是将变量的值通过已经保存的模型加载进来。代码中重复定义了计算图,也可以省略直接加载已经持久化的图。如下:

import tensorflow as tf
#直接加载持久化的图

saver = tf.train.import_meta_graph(
    "/path/to/model/model.ckpt.meta"
)
with tf.Session() as sess:
    saver.restore(sess, "/path/to/model/model.ckpt")
    #通过张量的名称来获取张量
    print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))

保存和加载部分变量

    在上面的程序中,默认保存和加载了tensorflow计算图上的全部变量。但有时需要保存和加载部分变量,这是可以通过列表来指定需要保存或加载的变量。例如在上面的程序中写saver = tf.train.Saver([v1]),那么只有变量v1会被加载进来。

保存或者加载时给变量重命名

tensorflow可以通过字典将模型保存时的变量名和需要加载的变量联系起来。例子如下:

#这里声明的变量和已经保存的模型中的变量不同
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='other-v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='other-v2')
#字典指定了原来名称为v1的变量现在加载到变量v1中,other-v1v2 名称为,一样
saver = tf.train.Saver({"v1":v1, "v2":v2})

这样做可以方便使用变量的滑动平均值。

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章