Saver
Tensorflow中,用 tensorflow.train.Saver來保存、恢復變量。
保存變量
tf.train.Saver.save(sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix='meta', write_meta_graph=True, write_state=True)
#————————————————————例子————————————————————————
import tensorflow as tf
# 創建兩個變量
v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")
# 添加用於初始化變量的節點
init_op = tf.global_variables_initializer()
# Create a saver.
saver = tf.train.Saver(tf.global_variables())
# 運行,保存變量
sess = tf.Session()
saver.save(sess,'my-model')
Saver可以使用提供的計數器自動爲checkpoint文件編號。這使得在訓練模型時在不同的步驟保留多個檢查點。在save()方法中傳遞可選的global_step參數,可以對checkpoint文件進行編號
saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
#——————————————————例子————————————————————————
import tensorflow as tf
# 創建兩個變量
v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")
# 添加用於初始化變量的節點
init_op = tf.global_variables_initializer()
# Create a saver.
saver = tf.train.Saver(tf.global_variables())
# 運行圖,打開會話,每1000次保存一個模型
sess = tf.Session()
for step in range(10000):
sess.run(init_op)
if step % 1000 == 0:
saver.save(sess, base_path+'my-model', global_step=step)
運行結果:
恢復變量
tf.train.Saver.restore(sess, save_path)
#————————————————例子——————————————————————
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(tf.global_variables())
#arg:獲取最近一次保存的變量文件名稱
module_file = tf.train.latest_checkpoint('my-model')
print(module_file)
saver.restore(sess, module_file)
欲瞭解saver更詳細的內容,請戳tensorflow.saver