Saver的用法

1. Saver的背景介紹

    我們經常在訓練完一個模型之後希望保存訓練的結果,這些結果指的是模型的參數,以便下次迭代的訓練或者用作測試。Tensorflow針對這一需求提供了Saver類。
  1. Saver類提供了向checkpoints文件保存和從checkpoints文件中恢復變量的相關方法。Checkpoints文件是一個二進制文件,它把變量名映射到對應的tensor值 。
  2. 只要提供一個計數器,當計數器觸發時,Saver類可以自動的生成checkpoint文件。這讓我們可以在訓練過程中保存多箇中間結果。例如,我們可以保存每一步訓練的結果。
  3. 爲了避免填滿整個磁盤,Saver可以自動的管理Checkpoints文件。例如,我們可以指定保存最近的N個Checkpoints文件。

2. Saver的實例

下面以一個例子來講述如何使用Saver類

[python] view plain copy
  1. import tensorflow as tf  
  2. import numpy as np  
  3.   
  4. x = tf.placeholder(tf.float32, shape=[None1])  
  5. y = 4 * x + 4  
  6.   
  7. w = tf.Variable(tf.random_normal([1], -11))  
  8. b = tf.Variable(tf.zeros([1]))  
  9. y_predict = w * x + b  
  10.   
  11.   
  12. loss = tf.reduce_mean(tf.square(y - y_predict))  
  13. optimizer = tf.train.GradientDescentOptimizer(0.5)  
  14. train = optimizer.minimize(loss)  
  15.   
  16. isTrain = False  
  17. train_steps = 100  
  18. checkpoint_steps = 50  
  19. checkpoint_dir = ''  
  20.   
  21. saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b  
  22. x_data = np.reshape(np.random.rand(10).astype(np.float32), (101))  
  23.   
  24. with tf.Session() as sess:  
  25.     sess.run(tf.initialize_all_variables())  
  26.     if isTrain:  
  27.         for i in xrange(train_steps):  
  28.             sess.run(train, feed_dict={x: x_data})  
  29.             if (i + 1) % checkpoint_steps == 0:  
  30.                 saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1)  
  31.     else:  
  32.         ckpt = tf.train.get_checkpoint_state(checkpoint_dir)  
  33.         if ckpt and ckpt.model_checkpoint_path:  
  34.             saver.restore(sess, ckpt.model_checkpoint_path)  
  35.         else:  
  36.             pass  
  37.         print(sess.run(w))  
  38.         print(sess.run(b))  

isTrain:用來區分訓練階段和測試階段,True表示訓練,False表示測試
train_steps:表示訓練的次數,例子中使用100
checkpoint_steps:表示訓練多少次保存一下checkpoints,例子中使用50
checkpoint_dir:表示checkpoints文件的保存路徑,例子中使用當前路徑


2.1 訓練階段

使用Saver.save()方法保存模型:
  1. sess:表示當前會話,當前會話記錄了當前的變量值
  2. checkpoint_dir + 'model.ckpt':表示存儲的文件名
  3. global_step:表示當前是第幾步
訓練完成後,當前目錄底下會多出5個文件。

    打開名爲“checkpoint”的文件,可以看到保存記錄,和最新的模型存儲位置。

2.1測試階段

    測試階段使用saver.restore()方法恢復變量:
  1. sess:表示當前會話,之前保存的結果將被加載入這個會話
  2. ckpt.model_checkpoint_path:表示模型存儲的位置,不需要提供模型的名字,它會去查看checkpoint文件,看看最新的是誰,叫做什麼。
    運行結果如下圖所示,加載了之前訓練的參數w和b的結果
發佈了44 篇原創文章 · 獲贊 79 · 訪問量 41萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章