Tensorflow保存和重载参数

Tensorflow保存和重载参数

参考链接:

https://www.cnblogs.com/houkai/p/9723988.html
https://blog.csdn.net/LordofRobots/article/details/77719020

前言:

虽然现在一直在用TensorFlow,但总有一种49年入国军的感觉。
可最近还必须要用这个,且经常会出现一些稀奇古怪的bug,因此不得不查很多相关的问题。

bug描述:

在VGG16训练好了之后,调用save函数保存,然后需要用的时候重载VGG16,调用restore函数,出现的结果很奇异——不报错,但是重新调用的VGG16预测精度,比训练的过程中,测试集精度低很多!

二者的代码如下:

    def save(self, path):
        saver = tf.train.Saver()
        saver.save(self.sess, save_path=path+'/params', write_meta_graph=False)

    def restore(self, path):
        saver = tf.train.Saver()
        saver.restore(self.sess, tf.train.latest_checkpoint(path))
        print("restore model successful")

bug解决过程:

  1. 首先通过对输入图像的格式进行排查,验证两次模型加载的输入图像的格式:RGB通道,是否归一化,size等都进行验证。
    发现完全一致!(其实调用的函数都是一致的)
  2. 那就只能看是否是因为模型保存错误,或者我当初的测试集效果好,由于用错了,训练集的数据。
  3. 排除实验:在重新加载的模型中输入训练集的数据,发现效果仍然还是很差!
  4. 那就只能是模型保存错误了!

bug解决过程:

其实到现在为止,我还是不清楚为什么会出现这个bug,也不知道是如何解决的~

因为我动了两个变量,一个是修改了上面原先的函数,同时加了一个保存每个epoch的参数的功能。

修改上面的函数如下:

  def save(self, path):
        saver = tf.train.Saver()
        saver.save(self.sess, save_path=path+'/params')

    def restore(self, path):
        saver = tf.train.Saver()
        saver.restore(self.sess, tf.train.latest_checkpoint(path))
        print("restore model successful")

其次额外加了一个这样的功能:
但是下面的保存可以,但是读取的时候,报错说format格式不对,可能也是因为我同时用了两种方式保存。
但是这时候,第一种保存和重载的方式竟然没有精度误差了!
就很有意思~

    def save_weight(self, mode, path, sess=None):
        assert(mode in ['latest', 'best'])
        if sess is None:
            sess_ = self.sess
        else:
            sess_ = sess
        saver = self.saver if mode == 'latest' else self.best_saver
        saver.save(sess_, path, global_step=self.learning_step)
        print('save', mode, 'model in', path, 'successfully')

    def load_weight(self, mode, path, sess=None):
        assert (mode in ['latest', 'best'])
        if sess is None:
            sess_ = self.sess
        else:
            sess_ = sess
        saver = self.saver if mode == 'latest' else self.best_saver
        ckpt = tf.train.get_checkpoint_state(path)
        print("ckpt:", ckpt)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess_, path)
            print('load', mode, 'model in', path, 'successfully')
        else:
            raise FileNotFoundError('Not Found Model File!')

接下来的探索:

只修改原先的函数,也只调用一次参数保存。

看看预测精度是否一致,如果一致,那么应该主要原因在于那个False.

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章