TensorFlow模型凍結(ckpt轉爲pb)筆記

# -*- coding:utf-8 -*-

"""
 This file used to freeze tensorflow .ckpt to .pb
"""

import tensorflow as tf


#兩種方式 方法1:函數方法,傳入session
def freeze_session(session, keep_var_name=None, output_names=None, clear_devices=True):
    graph = session.graph
    with graph.as_default():
        #difference方法 返回的值在global_variables中單不在keep_var_name中
        freeze_var_name = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_name or []))
        output_names = output_names or []
        output_names += [v.op.name for v in tf.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.decive = ''

        frozen_graph = tf.graph_util.convert_variables_to_constants(session, input_graph_def, output_names, freeze_var_name)
        return frozen_graph

#方法1第二步 將凍結的模型保存爲pb格式
#其中上一個函數中 output_name保存的就是節點
session = tf.Session()
net_model = '讀取網絡模型'
output_path = ''
pb_model_name = 'xxxx.pb'
frozen_graph = freeze_session(session, output_names=[net_model.output.op.name])
tf.python.framework.graph_io.write_graph(frozen_graph, output_path, pb_model_name, as_text=False)

#------------------------------------------------------------------------------------------------------------
#方法二 直接凍結
#1.指定模型輸出
output_nodes = tf.global_variables()
#utput_nodes = ["Accuracy/prediction", "Metric/Dice"] 指定模型輸出, 這樣可以允許自動裁剪無關節點. 這裏認爲使用逗號分割

#加載模型
saver = tf.train.import_meta_graph('model.ckpt.meta', clear_devices=True)

with tf.Session(graph=tf.get_default_graph()) as sess:
    #序列化模型
    input_graph_def = sess.graph.as_graph_def()
    #載入權重
    saver.restore(sess, 'model.ckpt')
    #轉換變量爲常量
    output_graph_def = tf.graph_util.convert_variables_to_constants(sess, input_graph_def, output_nodes)

    #寫入pb文件
    with open('frozen_model.pb', 'wb') as f:
        f.write(output_graph_def.SerializeToString())
        

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章