【tensorflow 大馬哈魚】Advanced saving and restoring with the Supervisor module

Reference tutorial:

https://blog.csdn.net/YiRanZhiLiPoSui/article/details/81143166


 

Introductory references:

https://blog.csdn.net/u012436149/article/details/53341372
A simple end-to-end walkthrough; good for getting started.

https://www.jianshu.com/p/7490ebfa3de8
Chinese translation of the official TensorFlow Supervisor guide: a helper for long-running training.

https://www.tensorflow.org/versions/r1.1/programmers_guide/supervisor
The official TensorFlow Supervisor guide.

https://www.tensorflow.org/api_docs/python/tf/train/Supervisor
The official Supervisor API reference, though it lacks complete examples.

 


1. Without Supervisor

Without Supervisor, training code is typically organized like this:

# variables
# ...
# ops
# ...
# summary ops
# ...
merge_op = tf.summary.merge_all()
saver = tf.train.Saver()
init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    writer = tf.summary.FileWriter(logdir, sess.graph)
    sess.run(init_op)
    saver.restore(sess, checkpoint_path)  # only if a checkpoint already exists
    for i in range(num_steps):
        sess.run(train_op)                # one training step
        merged_summary = sess.run(merge_op)
        writer.add_summary(merged_summary, i)
    saver.save(sess, checkpoint_path)
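For concreteness, here is a minimal runnable version of that pattern. The toy graph deliberately mirrors the Supervisor examples below; the ./log_manual path and the restore-if-exists check are illustrative additions:

import tensorflow as tf

tf.reset_default_graph()
a = tf.Variable(1)
b = tf.Variable(2)
update = tf.assign(a, tf.add(a, b))   # toy "training" op
tf.summary.scalar('a', a)

merge_op = tf.summary.merge_all()
saver = tf.train.Saver()
init_op = tf.global_variables_initializer()
logdir = './log_manual'               # illustrative path

with tf.Session() as sess:
    writer = tf.summary.FileWriter(logdir, sess.graph)
    sess.run(init_op)
    # Restoring must be handled by hand -- Supervisor automates this.
    ckpt = tf.train.latest_checkpoint(logdir)
    if ckpt:
        saver.restore(sess, ckpt)
    for i in range(50):
        _, merged_summary = sess.run([update, merge_op])
        writer.add_summary(merged_summary, i)
    saver.save(sess, logdir + '/model')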

2. With Supervisor

A single logdir directory is used to store both the model graph and the weight parameters:

sv = tf.train.Supervisor(logdir=logs_path, init_op=init_op, summary_op=None)  # logdir stores both checkpoints and summaries

Note the summary_op argument. If you do not pass summary_op=None, Supervisor uses its own built-in summary service, which evaluates the summaries on a background thread.

tf.train.Supervisor() also takes care of variable initialization automatically.

It even works with no arguments at all, though it is best to pass logdir. Note that the logdir given to Supervisor and the path passed to sv.saver.save may differ.

import tensorflow as tf
tf.reset_default_graph()
a = tf.Variable(1)
b = tf.Variable(2)
c = tf.add(a, b)
update = tf.assign(a, c)
logs_path = './logaa'

# No explicit initialization is needed:
# init_op = tf.global_variables_initializer()
# sv = tf.train.Supervisor(logdir=logs_path, init_op=init_op)  # logdir stores checkpoints and summaries

# This works too; it is best to pass logdir:
sv = tf.train.Supervisor(logdir=logs_path)

with sv.managed_session() as sess:  # looks for a checkpoint in logdir; if none, runs initialization automatically
    for i in range(71):
        update_ = sess.run(update)
        print(update_)
        if i % 10 == 0:
            sv.saver.save(sess, logs_path + '/model', global_step=i)
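Because managed_session restores automatically, running the script a second time continues from the last checkpoint rather than starting over. To see which checkpoint it will pick up, a small sketch (assumes the ./logaa directory from the run above):

import tensorflow as tf

# latest_checkpoint reads the 'checkpoint' index file that sv.saver.save maintains
print(tf.train.latest_checkpoint('./logaa'))  # e.g. ./logaa/model-70 after the loop above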

If you do pass summary_op=None, you must run the summary service yourself:

import tensorflow as tf
tf.reset_default_graph()
a = tf.Variable(1)
b = tf.Variable(2)
c = tf.add(a, b)
update = tf.assign(a, c)
logs_path = './logaa/'

tf.summary.scalar('a', a)
init_op = tf.global_variables_initializer()
merged_summary_op = tf.summary.merge_all()
sv = tf.train.Supervisor(logdir=logs_path, init_op=init_op, summary_op=None)  # logdir stores checkpoints and summaries

with sv.managed_session() as sess:  # looks for a checkpoint in logdir; if none, runs initialization automatically
    for i in range(1000):
        update_ = sess.run(update)
        print(update_)
        if i % 10 == 0:
            # Compute the summary ourselves and hand it to the Supervisor's writer
            merged_summary = sess.run(merged_summary_op)
            sv.summary_computed(sess, merged_summary)
        if i % 100 == 0:
            sv.saver.save(sess, logs_path + 'model', global_step=i)
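Conversely, if you want Supervisor to run the summary service for you, pass the merged op (or leave summary_op at its default). This only works when the summary op can be evaluated without a feed_dict, since the background thread calls sess.run(summary_op) with no feeds. A sketch reusing the names from the example above; save_summaries_secs is a standard Supervisor argument (default 120):

sv = tf.train.Supervisor(logdir=logs_path,
                         init_op=init_op,
                         summary_op=merged_summary_op,  # Supervisor's thread evaluates this
                         save_summaries_secs=30)        # and writes it every 30 seconds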

A complete example:

# -*- coding: utf-8 -*-

import tensorflow as tf
tf.reset_default_graph()
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

##### Build the graph
# Inputs: x and y
x = tf.placeholder(tf.float32, [None, 784], name='input_x')
y_ = tf.placeholder(tf.float32, [None, 10], name='input_y')

# Weight parameters
W = tf.Variable(tf.truncated_normal([784, 10], stddev=0.1), name='weights')
b = tf.Variable(tf.constant(0.1, shape=[10]), name='bias')

# Model
y_output = tf.nn.softmax(tf.matmul(x, W) + b)

# Cross-entropy loss (computed by hand; fine for a demo)
cross_entropy = -tf.reduce_sum(y_ * tf.log(y_output))
# Monitor the loss
tf.summary.scalar('loss', cross_entropy)
# tf.summary.scalar('loss', cross_entropy, collections=['loss'])
# Optimizer and training op
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

# Accuracy: compare the most likely predicted label with the true label
correct_prediction = tf.equal(tf.argmax(y_output, 1), tf.argmax(y_, 1))
accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


##### Build the session
# Log directory
logs_path = 'logsbbb/'
# Merge all summary nodes
merged_summary_op = tf.summary.merge_all()

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
# Create the Supervisor; summary_op=None because our summaries need a
# feed_dict, so we compute and hand them in ourselves
sv = tf.train.Supervisor(logdir=logs_path, init_op=tf.global_variables_initializer(), summary_op=None)
with sv.managed_session(config=config) as sess:
    # Hyperparameters
    ITERATION = 1000 + 1
    BATCH_SIZE = 64
    ITERATION_SHOW = 100

    for step in range(ITERATION):
        # Run one training step
        batch = mnist.train.next_batch(BATCH_SIZE)
        sess.run(train_step, feed_dict={x: batch[0], y_: batch[1]})

        if step % ITERATION_SHOW == 0:
            # Accuracy and summary on the current training batch
            merged_summary, accuracy = sess.run([merged_summary_op, accuracy_op], feed_dict={x: batch[0], y_: batch[1]})
            sv.summary_computed(sess, merged_summary, global_step=step)

            # Print the current accuracy
            print("step %d, accuracy: %.4g" % (step, accuracy))

            # Save the model
            sv.saver.save(sess, logs_path + 'model', global_step=step)
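The saved checkpoints can be reloaded with the same mechanism. A minimal evaluation sketch (assumes a fresh process that has rebuilt the same graph as above; summary_op=None matters here too, since the merged summary depends on placeholders that Supervisor's background thread cannot feed):

sv = tf.train.Supervisor(logdir=logs_path, summary_op=None)
with sv.managed_session() as sess:  # restores the latest checkpoint from logs_path
    test_accuracy = sess.run(accuracy_op,
                             feed_dict={x: mnist.test.images, y_: mnist.test.labels})
    print("test accuracy: %.4g" % test_accuracy)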

 
