Tensorflow保存和重載參數
參考鏈接:
https://www.cnblogs.com/houkai/p/9723988.html
https://blog.csdn.net/LordofRobots/article/details/77719020
前言:
雖然現在一直在用TensorFlow,但總有一種49年入國軍的感覺。
可最近還必須要用這個,且經常會出現一些稀奇古怪的bug,因此不得不查很多相關的問題。
bug描述:
在VGG16訓練好了之後,調用save函數保存,然後需要用的時候重載VGG16,調用restore函數,出現的結果很奇異——不報錯,但是重新調用的VGG16預測精度,比訓練的過程中,測試集精度低很多!
二者的代碼如下:
def save(self, path):
saver = tf.train.Saver()
saver.save(self.sess, save_path=path+'/params', write_meta_graph=False)
def restore(self, path):
saver = tf.train.Saver()
saver.restore(self.sess, tf.train.latest_checkpoint(path))
print("restore model successful")
bug解決過程:
- 首先通過對輸入圖像的格式進行排查,驗證兩次模型加載的輸入圖像的格式:RGB通道,是否歸一化,size等都進行驗證。
發現完全一致!(其實調用的函數都是一致的) - 那就只能看是否是因爲模型保存錯誤,或者我當初的測試集效果好,由於用錯了,訓練集的數據。
- 排除實驗:在重新加載的模型中輸入訓練集的數據,發現效果仍然還是很差!
- 那就只能是模型保存錯誤了!
bug解決過程:
其實到現在爲止,我還是不清楚爲什麼會出現這個bug,也不知道是如何解決的~
因爲我動了兩個變量,一個是修改了上面原先的函數,同時加了一個保存每個epoch的參數的功能。
修改上面的函數如下:
def save(self, path):
saver = tf.train.Saver()
saver.save(self.sess, save_path=path+'/params')
def restore(self, path):
saver = tf.train.Saver()
saver.restore(self.sess, tf.train.latest_checkpoint(path))
print("restore model successful")
其次額外加了一個這樣的功能:
但是下面的保存可以,但是讀取的時候,報錯說format格式不對,可能也是因爲我同時用了兩種方式保存。
但是這時候,第一種保存和重載的方式竟然沒有精度誤差了!
就很有意思~
def save_weight(self, mode, path, sess=None):
assert(mode in ['latest', 'best'])
if sess is None:
sess_ = self.sess
else:
sess_ = sess
saver = self.saver if mode == 'latest' else self.best_saver
saver.save(sess_, path, global_step=self.learning_step)
print('save', mode, 'model in', path, 'successfully')
def load_weight(self, mode, path, sess=None):
assert (mode in ['latest', 'best'])
if sess is None:
sess_ = self.sess
else:
sess_ = sess
saver = self.saver if mode == 'latest' else self.best_saver
ckpt = tf.train.get_checkpoint_state(path)
print("ckpt:", ckpt)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess_, path)
print('load', mode, 'model in', path, 'successfully')
else:
raise FileNotFoundError('Not Found Model File!')
接下來的探索:
只修改原先的函數,也只調用一次參數保存。
看看預測精度是否一致,如果一致,那麼應該主要原因在於那個False.