tesnsorflow部署.pb

免責聲明:本文僅代表個人觀點,如有錯誤,請讀者自己鑑別;如果本文不小心含有別人的原創內容,請聯繫我刪除;本人心血製作,若轉載請註明出處

1、保存模型爲.ckpt文件

saver = tf.train.Saver()
save_path = saver.save(sess, model_path)# 保存模型 其中model_path爲模型保存的文件
保存後的模型有四個文件 checkpoint、SFCN.ckpt.data-00000-of-00001、SFCN.ckpt.index、SFCN.ckpt.meta
2、保存event, event爲事件的保存路徑

summary_writer = tf.summary.FileWriter(event, graph=sess.graph)

在命令窗口中輸入 tensorboard --logdir==事件路徑

在瀏覽器中可以可視化tensorboard,可以可視化tensorboard,查看圖中有哪些節點

3、將ckpt文件轉化爲.pb文件,注意output_node_names = "Placeholder,Placeholder_2,keep_probabilty,conv2d_transpose_3" 這一行要寫入所有點節點,注意節點之間不可以加空格

import tensorflow as tf
from tensorflow.python.framework import graph_util
def freeze_graph(input_checkpoint, output_graph):
    '''
    :param input_checkpoint:
    :param output_graph: PB模型保存路徑
    :return:
    '''
    # 指定輸出的節點名稱,該節點名稱必須是原模型中存在的節點
    # 直接用最後輸出的節點,可以在tensorboard中查找到,tensorboard只能在linux中使用
    output_node_names = "Placeholder,Placeholder_2,keep_probabilty,conv2d_transpose_3"
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    graph = tf.get_default_graph()  # 獲得默認的圖
    input_graph_def = graph.as_graph_def()  # 返回一個序列化的圖代表當前的圖

    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)  # 恢復圖並得到數據
        output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,將變量值固定
            sess=sess,
            input_graph_def=input_graph_def,  # 等於:sess.graph_def
            output_node_names=output_node_names.split(","))  # 如果有多個輸出節點,以逗號隔開

        with tf.gfile.GFile(output_graph, "wb") as f:  # 保存模型
            f.write(output_graph_def.SerializeToString())  # 序列化輸出
        print("%d ops in the final graph." % len(output_graph_def.node))  # 得到當前圖有幾個操作節點


input_checkpoint = "./checkpoint/SFCN.ckpt"  # 輸入的ckpt文件位置
output_graph = "node.pb"  # 輸出節點的文件名
freeze_graph(input_checkpoint, output_graph)

4、調用.pb文件,注意y_in = sess.graph.get_tensor_by_name("Placeholder:0") images_in = sess.graph.get_tensor_by_name("Placeholder_2:0") keep_probability_in = sess.graph.get_tensor_by_name("keep_probabilty:0") logits_out = sess.graph.get_tensor_by_name("conv2d_transpose_3:0")

這四個節點要與第3步中的節點保持一致,但是要在後面加入“:0”,如第3步中節點爲“Placeholder”,第4步要寫爲“Placeholder:0”

global graph
graph = tf.get_default_graph()
with graph.as_default():
    output_graph_def = tf.GraphDef()
    with open(model_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
        tf.import_graph_def(output_graph_def, name="")
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        y_in = sess.graph.get_tensor_by_name("Placeholder:0")
        images_in = sess.graph.get_tensor_by_name("Placeholder_2:0")
        keep_probability_in = sess.graph.get_tensor_by_name("keep_probabilty:0")
        logits_out = sess.graph.get_tensor_by_name("conv2d_transpose_3:0")
        batch_size = 1
        testTime = 0
        predictLabel = tf.zeros(trainLabel.shape)
        predictLabel = sess.run(predictLabel)
        for i in range(0, smallImage):
            realbatch_array, real_labels, real_index = getNext_batch(trainData, trainLabel, trainIndex, i)
            testStart = time.time()
            yy = sess.run(logits_out,
                          feed_dict={images_in: realbatch_array, y_in: real_labels, keep_probability_in: 1.0})
            predictLabel[i, ...] = yy
            testEnd = time.time()
            testTime1 = testEnd - testStart
            testTime += testTime1

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章