使用tensorflow版本爲1.12
在訓練過程中保存模型
def save(sess, save_path):
"""save_path is a file path"""
self.saver.save(sess, save_path=save_path)
預測或updating模型時,加載模型
def restore(model_save_path):
if os.path.isdir(model_save_path):
# 根據目錄名稱來獲取最新的模型的文件路徑
model_file_path = tf.train.latest_checkpoint(checkpoint_dir=dir_name)
else:
model_file_path = model_save_path
saver = tf.train.saver()
saver.restore(sess, save_path=mdoel_file_path)
加載最新訓練的模型的方法具體的流程:
checkpoint_dir中有一個checkpoint文件,裏面記錄了所有保存的模型的路徑和model_checkpoint_path
變量。其中model_checkpoint_path
記錄了最新的一個模型的路徑,調用tf.train.latest_checkpoint()
函數,得到的就是這個路徑。
總結
- tensorflow 使用
tf.train.saver()
的save()
和restore()
函數來保存和加載模型 save
和restore
函數需要兩個參數一個是session,另一個模型的文件路徑(注意不目錄)
接下來的問題是:
- saver在什麼時候初始化?
當保存模型時,saver在初始化Session的時候初始化,可以用作成員變量;
而在加載模型的時候需要單獨的初始化saver,因爲如果是train方法內實體化的saver在restore方法無法使用。 - session在什麼時候初始化?
在Model()實體化後,再實例化Session,再使用session來init所有的變量,否則會報變量沒有實體化的錯誤。