參考的教程
https://blog.csdn.net/YiRanZhiLiPoSui/article/details/81143166
參考入門文章:
https://blog.csdn.net/u012436149/article/details/53341372
給出了簡單的完整流程,便於入門理解
https://www.jianshu.com/p/7490ebfa3de8
tensorflow官網出的Supervisor介紹 的中文翻譯版:長期訓練好幫手
https://www.tensorflow.org/versions/r1.1/programmers_guide/supervisor
tensorflow官網出的Supervisor介紹
https://www.tensorflow.org/api_docs/python/tf/train/Supervisor
官方的Supervisor接口文檔。不過缺乏完整的例子。
一、不使用Supervisor的情況
在不使用Supervisor的時候,我們的代碼經常是這麼組織的:
variables
...
ops
...
summary_op
...
merge_op = tf.summary.merge_all()
saver
init_op
with tf.Session() as sess:
writer = tf.summary.FileWriter()
sess.run(init)
saver.restore()
for ...:
train
merged_summary = sess.run(merge_op)
writer.add_summary(merged_summary,i)
saver.save
二、使用Supervisor的情況
使用一個logdir目錄 來同時保存 模型圖 和 權重參數
sv = tf.train.Supervisor(logdir=logs_path,init_op=init_op,summary_op=None) #logdir用來保存checkpoint和summary
注意有個參數是summary_op
如果不傳入 summary_op=None(即保留默認值),則使用Supervisor自帶的summary服務
使用sv = tf.train.Supervisor() 會自動初始化。
無參數也可以,最好加上logdir,同時,兩個logdir可以不同
import tensorflow as tf

tf.reset_default_graph()

# Minimal demo graph: each run of `update` folds a+b back into a,
# so the printed value grows by 2 every iteration.
a = tf.Variable(1)
b = tf.Variable(2)
update = tf.assign(a, tf.add(a, b))

logs_path = './logaa'

# No explicit init_op is needed: Supervisor restores from the latest
# checkpoint found in logdir, and falls back to initializing all
# variables itself when no checkpoint exists.
sv = tf.train.Supervisor(logdir=logs_path)  # logdir holds checkpoints and summaries

with sv.managed_session() as sess:
    for i in range(71):
        print(sess.run(update))
        if i % 10 == 0:
            # Periodically checkpoint using the Supervisor's own saver.
            sv.saver.save(sess, logs_path + '/model', global_step=i)
如果傳入了 summary_op=None,則需自建summary服務
import tensorflow as tf

tf.reset_default_graph()

# Same toy graph as the previous demo: running `update` adds b into a.
a = tf.Variable(1)
b = tf.Variable(2)
c = tf.add(a, b)
update = tf.assign(a, c)

logs_path = './logaa/'

# Track `a` so the summary service has something to record.
tf.summary.scalar('a', a)
init_op = tf.global_variables_initializer()
merged_summary_op = tf.summary.merge_all()

# summary_op=None disables Supervisor's built-in summary thread; summaries
# must then be computed manually and handed in via sv.summary_computed().
sv = tf.train.Supervisor(logdir=logs_path, init_op=init_op, summary_op=None)  # logdir holds checkpoints and summaries

with sv.managed_session() as sess:  # restores from a logdir checkpoint, or runs init_op
    for i in range(1000):
        update_ = sess.run(update)
        print(update_)
        if i % 10 == 0:
            merged_summary = sess.run(merged_summary_op)
            sv.summary_computed(sess, merged_summary)
        if i % 100 == 0:
            # BUG FIX: saving to `logs_path` (a directory path ending in '/')
            # produced checkpoints with an empty basename such as 'logaa/-100'.
            # Use an explicit file prefix inside the log directory instead,
            # consistent with the previous example.
            sv.saver.save(sess, logs_path + 'model', global_step=i)
一個完整的例子
# -*- coding: utf-8 -*-
import tensorflow as tf
tf.reset_default_graph()
import tensorflow.examples.tutorials.mnist.input_data as input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

##### Build the graph
# Inputs: flattened 28x28 images and one-hot class labels.
x = tf.placeholder(tf.float32, [None, 784], name='input_x')
y_ = tf.placeholder(tf.float32, [None, 10], name='input_y')

# Single-layer softmax classifier parameters.
W = tf.Variable(tf.truncated_normal([784, 10], stddev=0.1), name='weights')
b = tf.Variable(tf.constant(0.1, shape=[10]), name='bias')

# Model: logits plus their softmax probabilities.
logits = tf.matmul(x, W) + b
y_output = tf.nn.softmax(logits)

# BUG FIX: the hand-rolled -reduce_sum(y_ * log(softmax)) yields NaN as soon
# as a predicted probability underflows to 0. The fused op below computes the
# same batch-summed cross-entropy in a numerically stable way.
cross_entropy = tf.reduce_sum(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))
# Monitor the loss in TensorBoard.
tf.summary.scalar('loss', cross_entropy)

# Optimizer / training op.
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

# Accuracy: fraction of samples whose arg-max prediction matches the label.
correct_prediction = tf.equal(tf.argmax(y_output, 1), tf.argmax(y_, 1))
accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

##### Build the session
# Directory where Supervisor stores checkpoints and summaries.
logs_path = 'logsbbb/'
# Collect every registered summary node into one op.
merged_summary_op = tf.summary.merge_all()

config = tf.ConfigProto()
config.gpu_options.allow_growth = True

# summary_op=None: we hand summaries in manually via sv.summary_computed().
sv = tf.train.Supervisor(logdir=logs_path,
                         init_op=tf.global_variables_initializer(),
                         summary_op=None)

with sv.managed_session(config=config) as sess:
    # Hyper-parameters.
    ITERATION = 1000 + 1
    BATCH_SIZE = 64
    ITERATION_SHOW = 100
    for step in range(ITERATION):
        # One training step on a fresh mini-batch.
        batch = mnist.train.next_batch(BATCH_SIZE)
        sess.run(train_step, feed_dict={x: batch[0], y_: batch[1]})
        if step % ITERATION_SHOW == 0:
            # Accuracy on the current training batch plus merged summaries.
            merged_summary, accuracy = sess.run(
                [merged_summary_op, accuracy_op],
                feed_dict={x: batch[0], y_: batch[1]})
            sv.summary_computed(sess, merged_summary, global_step=step)
            # Typo fixed in the report string: 'accuarcy' -> 'accuracy'.
            print("step %d, accuracy:%.4g" % (step, accuracy))
            # BUG FIX: saving to `logs_path` (a bare directory path) produced
            # checkpoints with an empty basename; use a file prefix inside
            # the log directory instead.
            sv.saver.save(sess, logs_path + 'model', global_step=step)