Tensorflow保存和重载参数
参考链接:
https://www.cnblogs.com/houkai/p/9723988.html
https://blog.csdn.net/LordofRobots/article/details/77719020
前言:
虽然现在一直在用TensorFlow,但总有一种49年入国军的感觉。
可最近还必须要用这个,且经常会出现一些稀奇古怪的bug,因此不得不查很多相关的问题。
bug描述:
在VGG16训练好了之后,调用save函数保存,然后需要用的时候重载VGG16,调用restore函数,出现的结果很奇异——不报错,但是重新调用的VGG16预测精度,比训练的过程中,测试集精度低很多!
二者的代码如下:
def save(self, path):
saver = tf.train.Saver()
saver.save(self.sess, save_path=path+'/params', write_meta_graph=False)
def restore(self, path):
saver = tf.train.Saver()
saver.restore(self.sess, tf.train.latest_checkpoint(path))
print("restore model successful")
bug解决过程:
- 首先通过对输入图像的格式进行排查,验证两次模型加载的输入图像的格式:RGB通道,是否归一化,size等都进行验证。
发现完全一致!(其实调用的函数都是一致的) - 那就只能看是否是因为模型保存错误,或者我当初的测试集效果好,由于用错了,训练集的数据。
- 排除实验:在重新加载的模型中输入训练集的数据,发现效果仍然还是很差!
- 那就只能是模型保存错误了!
bug解决过程:
其实到现在为止,我还是不清楚为什么会出现这个bug,也不知道是如何解决的~
因为我动了两个变量,一个是修改了上面原先的函数,同时加了一个保存每个epoch的参数的功能。
修改上面的函数如下:
def save(self, path):
saver = tf.train.Saver()
saver.save(self.sess, save_path=path+'/params')
def restore(self, path):
saver = tf.train.Saver()
saver.restore(self.sess, tf.train.latest_checkpoint(path))
print("restore model successful")
其次额外加了一个这样的功能:
但是下面的保存可以,但是读取的时候,报错说format格式不对,可能也是因为我同时用了两种方式保存。
但是这时候,第一种保存和重载的方式竟然没有精度误差了!
就很有意思~
def save_weight(self, mode, path, sess=None):
assert(mode in ['latest', 'best'])
if sess is None:
sess_ = self.sess
else:
sess_ = sess
saver = self.saver if mode == 'latest' else self.best_saver
saver.save(sess_, path, global_step=self.learning_step)
print('save', mode, 'model in', path, 'successfully')
def load_weight(self, mode, path, sess=None):
assert (mode in ['latest', 'best'])
if sess is None:
sess_ = self.sess
else:
sess_ = sess
saver = self.saver if mode == 'latest' else self.best_saver
ckpt = tf.train.get_checkpoint_state(path)
print("ckpt:", ckpt)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess_, path)
print('load', mode, 'model in', path, 'successfully')
else:
raise FileNotFoundError('Not Found Model File!')
接下来的探索:
只修改原先的函数,也只调用一次参数保存。
看看预测精度是否一致,如果一致,那么应该主要原因在于那个False.