tensorflow读取一个模型后多次使用

训练好一个模型后,将其投入使用,会有在项目初始化后多次加载测试数据的需求,可以采用保存graph的思想实现

(在一个项目中需要加载多个模型同样可用)

另:这条博客接我的上一条https://blog.csdn.net/qq_34470213/article/details/104076898,是在上一个代码的基础上改写的。

1、新建文件test.py,建一个类Model_test,用来保存模型,包括一个初始化方法,用来初始化模型(项目中仅需初始化时调用一次),一个测试调用方法,用来调用模型进行测试(每次测试调用一次)。

class Model_test():
    def restore(self):
        self.model = Model.LeNet5(1, 5)
        path = "D:/model/model/model.ckpt"
        self.model.load(path)

    def restore_test(self, image_path):
        image = Process.process_one(image_path)
        sort = self.model.test1(image)
        return sort

2、在model.py的类中添加初始化函数和测试函数,这里和之前的测试函数的差别在于拆分开了加载和测试的部分,并且将graph和session保存为了类属性变量。

    def load(self, model_path):
        self.graph = tf.Graph()
        self.sess = tf.Session(graph=self.graph)
        with self.sess.graph.as_default():
            self.sess.run(tf.global_variables_initializer())
            self.sess.run(tf.local_variables_initializer())
            saver11 = tf.train.import_meta_graph(model_path+'.meta',
                                               clear_devices=True)
            saver11.restore(self.sess, model_path)

    def test1(self, image):
        x = tf.placeholder(tf.float32, [None, 64, 64, 1], name='x-input')
        self.activation = self.graph.get_tensor_by_name('layer6-fc2/add:0')
        image = np.array(image) / 255.0
        image = np.reshape(image, (-1, 64, 64, 1))
        logit = tf.arg_max(self.activation, 1)
        y, label = self.sess.run((self.activation, logit), feed_dict={'x-input:0': image})
        return label

3、以上两步就可以成功实现了,调用方法为:

tm = test.Model_test()
tm.restore()


……


while(True){
    sort = tm.restore_test(pathname[i])
}

……

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章