Keras h5 模型转 pb 文件

关键代码如下(这里的模型为 tf.keras 模型):

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

# `model` is assumed to be an already-built tf.keras model with its weights
# loaded (see the complete script below).

# Trace the model's forward pass into a ConcreteFunction with a fixed input
# signature taken from the model's first input.
full_model = tf.function(lambda x: model(x))
full_model = full_model.get_concrete_function(
    tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

# Get frozen ConcreteFunction: variable reads are replaced by constants.
frozen_func = convert_variables_to_constants_v2(full_model)

# Serialize the frozen graph as a binary GraphDef file.
# An empty logdir presumably resolves to the current working directory.
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                  logdir="",
                  name="frozen_graph.pb",
                  as_text=False)

PS:建议使用 TensorFlow 1.14 及以上版本进行转换。

完整代码:

# Copyright 2020 chenli Authors. All Rights Reserved.
import tensorflow as tf
import os
from mobilenetv3_factory import build_mobilenetv3


def mobilenetv3_creat(num_classes=2, img_size=128):
    """Build a MobileNetV3 "small" classifier.

    Args:
        num_classes (int): number of output classes.
        img_size (int): height and width of the square, 3-channel input.

    Returns:
        The model produced by ``build_mobilenetv3``.
    """
    # Fixed architecture hyper-parameters used for this classifier.
    arch_options = dict(width_multiplier=1.0, l2_reg=1e-5)
    input_shape = (img_size, img_size, 3)
    return build_mobilenetv3(
        "small",
        input_shape=input_shape,
        num_classes=num_classes,
        **arch_options
    )


def mobilenetv3_load_weight(model, model_path="mobilenetv3_small_10.h5"):
    """Restore saved weights into ``model``.

    Args:
        model: model object exposing ``load_weights`` (e.g. the result of
            ``mobilenetv3_creat``).
        model_path (str): path of the weights file to load.

    Returns:
        The same ``model`` object that was passed in, after
        ``load_weights`` has been called on it.
    """
    weights_file = model_path
    model.load_weights(weights_file)
    return model


if __name__ == "__main__":
    # The conversion below is written for eager mode. TF 2.x is already eager
    # and has no top-level tf.enable_eager_execution (calling it raises
    # AttributeError), so only switch eager on -- through the compat API that
    # exists in both TF 1.14+ and 2.x -- when it is not already active.
    if not tf.executing_eagerly():
        tf.compat.v1.enable_eager_execution()

    classes = ['ClientFace', 'ImposterFace']
    num_classes = len(classes)
    img_size = 128
    model_path = "mobilenetv3_small_10.h5"

    # Rebuild the architecture, then restore the trained weights from the .h5 file.
    model = mobilenetv3_creat(num_classes=num_classes, img_size=img_size)
    model = mobilenetv3_load_weight(model, model_path=model_path)

    from tensorflow.python.framework.convert_to_constants import (
        convert_variables_to_constants_v2,
    )

    # Trace the model's forward pass into a ConcreteFunction with a fixed
    # input signature taken from the model's first input.
    # NOTE(review): model(x) is called without training=False; confirm the
    # traced graph uses inference behaviour for any BatchNorm/Dropout layers.
    concrete_func = tf.function(lambda x: model(x)).get_concrete_function(
        tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

    # Get frozen ConcreteFunction: variable reads are replaced by constants
    # holding their current values, giving a self-contained graph.
    frozen_func = convert_variables_to_constants_v2(concrete_func)

    # Serialize the frozen graph as a binary GraphDef file.
    # An empty logdir presumably resolves to the current working directory.
    tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                      logdir="",
                      name="frozen_graph.pb",
                      as_text=False)

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章