Tensorflow中變量保存與恢復


Saver

Tensorflow中,用 tensorflow.train.Saver來保存、恢復變量。

保存變量

tf.train.Saver.save(sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix='meta', write_meta_graph=True, write_state=True)
#————————————————————例子————————————————————————
import tensorflow as tf
# 創建兩個變量
v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")

# 添加用於初始化變量的節點
init_op = tf.global_variables_initializer()

# Create a saver.
saver = tf.train.Saver(tf.global_variables())

# 運行,保存變量
sess = tf.Session()
saver.save(sess,'my-model')

Saver可以使用提供的計數器自動爲checkpoint文件編號。這使得在訓練模型時在不同的步驟保留多個檢查點。在save()方法中傳遞可選的global_step參數,可以對checkpoint文件進行編號

saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
#——————————————————例子————————————————————————
import tensorflow as tf
# 創建兩個變量
v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")

# 添加用於初始化變量的節點
init_op = tf.global_variables_initializer()

# Create a saver.
saver = tf.train.Saver(tf.global_variables())

# 運行圖,打開會話,每1000次保存一個模型
sess = tf.Session()
for step in range(10000):
    sess.run(init_op)
    if step % 1000 == 0:
        saver.save(sess, base_path+'my-model', global_step=step)

運行結果:
生成保存的模型變量

恢復變量

tf.train.Saver.restore(sess, save_path)
#————————————————例子——————————————————————
 sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        #arg:獲取最近一次保存的變量文件名稱
        module_file = tf.train.latest_checkpoint('my-model')
        print(module_file)
        saver.restore(sess, module_file)

欲瞭解saver更詳細的內容,請戳tensorflow.saver

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章