Tensorflow保存和重載參數

Tensorflow保存和重載參數

參考鏈接:

https://www.cnblogs.com/houkai/p/9723988.html
https://blog.csdn.net/LordofRobots/article/details/77719020

前言:

雖然現在一直在用TensorFlow,但總有一種49年入國軍的感覺。
可最近還必須要用這個,且經常會出現一些稀奇古怪的bug,因此不得不查很多相關的問題。

bug描述:

在VGG16訓練好了之後,調用save函數保存,然後需要用的時候重載VGG16,調用restore函數,出現的結果很奇異——不報錯,但是重新調用的VGG16預測精度,比訓練的過程中,測試集精度低很多!

二者的代碼如下:

    def save(self, path):
        saver = tf.train.Saver()
        saver.save(self.sess, save_path=path+'/params', write_meta_graph=False)

    def restore(self, path):
        saver = tf.train.Saver()
        saver.restore(self.sess, tf.train.latest_checkpoint(path))
        print("restore model successful")

bug解決過程:

  1. 首先通過對輸入圖像的格式進行排查,驗證兩次模型加載的輸入圖像的格式:RGB通道,是否歸一化,size等都進行驗證。
    發現完全一致!(其實調用的函數都是一致的)
  2. 那就只能看是否是因爲模型保存錯誤,或者我當初的測試集效果好,由於用錯了,訓練集的數據。
  3. 排除實驗:在重新加載的模型中輸入訓練集的數據,發現效果仍然還是很差!
  4. 那就只能是模型保存錯誤了!

bug解決過程:

其實到現在爲止,我還是不清楚爲什麼會出現這個bug,也不知道是如何解決的~

因爲我動了兩個變量,一個是修改了上面原先的函數,同時加了一個保存每個epoch的參數的功能。

修改上面的函數如下:

  def save(self, path):
        saver = tf.train.Saver()
        saver.save(self.sess, save_path=path+'/params')

    def restore(self, path):
        saver = tf.train.Saver()
        saver.restore(self.sess, tf.train.latest_checkpoint(path))
        print("restore model successful")

其次額外加了一個這樣的功能:
但是下面的保存可以,但是讀取的時候,報錯說format格式不對,可能也是因爲我同時用了兩種方式保存。
但是這時候,第一種保存和重載的方式竟然沒有精度誤差了!
就很有意思~

    def save_weight(self, mode, path, sess=None):
        assert(mode in ['latest', 'best'])
        if sess is None:
            sess_ = self.sess
        else:
            sess_ = sess
        saver = self.saver if mode == 'latest' else self.best_saver
        saver.save(sess_, path, global_step=self.learning_step)
        print('save', mode, 'model in', path, 'successfully')

    def load_weight(self, mode, path, sess=None):
        assert (mode in ['latest', 'best'])
        if sess is None:
            sess_ = self.sess
        else:
            sess_ = sess
        saver = self.saver if mode == 'latest' else self.best_saver
        ckpt = tf.train.get_checkpoint_state(path)
        print("ckpt:", ckpt)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess_, path)
            print('load', mode, 'model in', path, 'successfully')
        else:
            raise FileNotFoundError('Not Found Model File!')

接下來的探索:

只修改原先的函數,也只調用一次參數保存。

看看預測精度是否一致,如果一致,那麼應該主要原因在於那個False.

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章