關鍵代碼如下(這裏的 `model` 爲已構建並加載好權重的 tf.keras 模型,`tf` 爲已導入的 tensorflow):
# Excerpt: expects `tf` (tensorflow) to be imported already and `model` to be
# an already-built tf.keras model.
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)

# Wrap the Keras model in a tf.function and trace it once, using the shape and
# dtype of the model's first input, to obtain a ConcreteFunction.
# NOTE(review): only model.inputs[0] is used, so this presumably targets
# single-input models -- confirm before using it on multi-input models.
full_model = tf.function(lambda x: model(x))
full_model = full_model.get_concrete_function(
    tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype)
)

# Get frozen ConcreteFunction
frozen_func = convert_variables_to_constants_v2(full_model)

# Write the frozen graph in binary form. The Graph object is passed directly,
# so no separate as_graph_def() call is needed beforehand.
tf.io.write_graph(
    graph_or_graph_def=frozen_func.graph,
    logdir="",
    name="frozen_graph.pb",
    as_text=False,
)
PS:建議使用 TensorFlow 1.14 及以上版本進行轉換。
完整代碼:
# Copyright 2020 chenli Authors. All Rights Reserved.
import tensorflow as tf
import os
from mobilenetv3_factory import build_mobilenetv3
def mobilenetv3_creat(num_classes: int = 2, img_size: int = 128):
    """Create the MobileNetV3 "small" model.

    :param num_classes: number of classes.
    :param img_size: image size; the model is built with the square,
        3-channel input shape ``(img_size, img_size, 3)``.
    :return: the model produced by ``build_mobilenetv3``.
    """
    # width_multiplier and l2_reg are fixed here; only the input size and the
    # number of classes are configurable by the caller.
    return build_mobilenetv3(
        "small",
        input_shape=(img_size, img_size, 3),
        num_classes=num_classes,
        width_multiplier=1.0,
        l2_reg=1e-5,
    )
def mobilenetv3_load_weight(model, model_path: str = "mobilenetv3_small_10.h5"):
    """Load weights into ``model`` and return it.

    :param model: the model; it must provide a ``load_weights(path)`` method.
    :param model_path: path of the weights file.
    :return: the same ``model`` object that was passed in, after
        ``load_weights`` has been called on it.
    """
    model.load_weights(model_path)
    return model
if __name__ == "__main__":
    # Only this script entry point needs the freezing helper, so it is
    # imported locally rather than at module level.
    from tensorflow.python.framework.convert_to_constants import (
        convert_variables_to_constants_v2,
    )

    # Eager execution is needed for the tf.function tracing below. The
    # top-level `tf.enable_eager_execution` is not available in TF 2.x, so the
    # compat.v1 alias is used, and only when eager mode is not already active.
    if not tf.executing_eagerly():
        tf.compat.v1.enable_eager_execution()

    # NOTE(review): the original carried a commented-out learning-phase call
    # here; confirm the model is traced in inference mode if it contains
    # layers such as dropout or batch normalization.

    classes = ['ClientFace', 'ImposterFace']
    num_classes = len(classes)
    img_size = 128
    model_path = "mobilenetv3_small_10.h5"

    # Build the network, then load the trained weights into it.
    model = mobilenetv3_creat(num_classes=num_classes, img_size=img_size)
    model = mobilenetv3_load_weight(model, model_path=model_path)

    # Wrap the Keras model in a tf.function and trace it once, using the shape
    # and dtype of the model's first input, to obtain a ConcreteFunction.
    full_model = tf.function(lambda x: model(x))
    full_model = full_model.get_concrete_function(
        tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype)
    )

    # Get frozen ConcreteFunction
    frozen_func = convert_variables_to_constants_v2(full_model)

    # Write the frozen graph in binary form. The Graph object is passed
    # directly, so no separate as_graph_def() call is needed beforehand.
    tf.io.write_graph(
        graph_or_graph_def=frozen_func.graph,
        logdir="",
        name="frozen_graph.pb",
        as_text=False,
    )