# -*- coding:utf-8 -*-
"""
This file used to freeze tensorflow .ckpt to .pb
"""
import tensorflow as tf
#兩種方式 方法1:函數方法,傳入session
def freeze_session(session, keep_var_name=None, output_names=None, clear_devices=True):
graph = session.graph
with graph.as_default():
#difference方法 返回的值在global_variables中單不在keep_var_name中
freeze_var_name = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_name or []))
output_names = output_names or []
output_names += [v.op.name for v in tf.global_variables()]
input_graph_def = graph.as_graph_def()
if clear_devices:
for node in input_graph_def.node:
node.decive = ''
frozen_graph = tf.graph_util.convert_variables_to_constants(session, input_graph_def, output_names, freeze_var_name)
return frozen_graph
#方法1第二步 將凍結的模型保存爲pb格式
#其中上一個函數中 output_name保存的就是節點
session = tf.Session()
net_model = '讀取網絡模型'
output_path = ''
pb_model_name = 'xxxx.pb'
frozen_graph = freeze_session(session, output_names=[net_model.output.op.name])
tf.python.framework.graph_io.write_graph(frozen_graph, output_path, pb_model_name, as_text=False)
#------------------------------------------------------------------------------------------------------------
#方法二 直接凍結
#1.指定模型輸出
output_nodes = tf.global_variables()
#utput_nodes = ["Accuracy/prediction", "Metric/Dice"] 指定模型輸出, 這樣可以允許自動裁剪無關節點. 這裏認爲使用逗號分割
#加載模型
saver = tf.train.import_meta_graph('model.ckpt.meta', clear_devices=True)
with tf.Session(graph=tf.get_default_graph()) as sess:
#序列化模型
input_graph_def = sess.graph.as_graph_def()
#載入權重
saver.restore(sess, 'model.ckpt')
#轉換變量爲常量
output_graph_def = tf.graph_util.convert_variables_to_constants(sess, input_graph_def, output_nodes)
#寫入pb文件
with open('frozen_model.pb', 'wb') as f:
f.write(output_graph_def.SerializeToString())