tf.train.write_graph用法

我不是知識的生產者,我只是一個渺小的搬運工,我們都站在巨人的肩膀上


探索了一下午這玩意的用法,終於會用了,在此附上實例子

首先要明白他保存圖的原理,這個裏面講的很詳細,請細品

https://zhuanlan.zhihu.com/p/31308381

tf.train.write_graph這個函數可以保留節點,op,constant,但不保存variable,如果你想要保存variable,那麼就要轉爲constant

import tensorflow as tf
import numpy as np
from tensorflow.python.platform import gfile

#生成圖
input1= tf.placeholder(tf.int32,name="input")
b = tf.constant([3])
output1= tf.add(input1, b, name="output")

#保存圖
with tf.Session() as sess:
    tf.train.write_graph(sess.graph_def, "./", "test.pb", False)
    print(sess.run(output1,feed_dict={input1:1}))

#讀取圖
with tf.Session() as sess:
    with gfile.FastGFile("./test.pb",'rb') as f:
        graph = tf.get_default_graph()
        graph_def = graph.as_graph_def()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')


#查看圖中信息,填充運行圖
with tf.Session() as sess:
    input_x1 = sess.graph.get_tensor_by_name("input:0")  
    print (input_x1)   #可以看到這個placeholder的屬性
    output = sess.graph.get_tensor_by_name("output:0")
    print (output)
    data1 = int(3)
    print(sess.run(output,feed_dict={input_x1:data1}))  #填充placeholder,然後運行圖

#或者也可以直接讀入圖,運行
data1 = int(3)
with tf.Session() as sess:
    with gfile.FastGFile("./test.pb",'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        output = tf.import_graph_def(graph_def, input_map={'input:0':data1},             
        return_elements=['output:0'], name='a') 
        print(sess.run(output))  

print(len(graph_def.node))  #打印所有的op數

tensor_name = [tensor.name for tensor in graph_def.node]
print(tensor_name)   #打印所有的tensor名字

for op in graph.get_operations():
    print(op.name, op.values())  # print出tensor的name和值

 

同時tensorboard給我幫助我們模型結構可視化

在讀取文件時tf.summary.FileWriter保存

with tf.Session() as sess:
    with gfile.FastGFile("./test.pb",'rb') as f:
        graph = tf.get_default_graph()
        graph_def = graph.as_graph_def()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='graph')
        summaryWriter = tf.summary.FileWriter('log/', graph)  #存log

然後在終端下運行,省略寫法的話一般是自動會自動補充端口號6006

tensorboard --logdir log --host localhost --port 6006 或者

tensorboard --logdir log   #省略寫法

就會在終端生成一個可視化的連接了

 

 

解釋幾個問題:

1.sess.graph.get_tensor_by_name("input:0")是幹什麼的,爲什麼是input:0?

   答:是幫你獲取張量的,input是節點名稱,input:0是表述節點的輸出的第一個張量

2.如果圖中有變量,也想要保存,怎麼辦?

   答:保存圖的時候轉化成常量保存,graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])

最後:有問題歡迎指正聯繫

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章