Keras h5 模型轉 pb 文件

關鍵代碼如下（這裡的模型為 tf.keras 模型）：

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

# `model` is the already-loaded tf.keras model.
# Wrap its forward pass in a tf.function; training=False makes layers whose
# behaviour differs between training and inference (e.g. BatchNormalization,
# Dropout) get traced in inference mode.
full_model = tf.function(lambda x: model(x, training=False))
# Trace once for the model's (first) input signature.
full_model = full_model.get_concrete_function(
    tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

# Get frozen ConcreteFunction: every variable is inlined as a constant.
frozen_func = convert_variables_to_constants_v2(full_model)
# Serialize the frozen GraphDef as a binary protobuf in the current directory.
tf.io.write_graph(graph_or_graph_def=frozen_func.graph.as_graph_def(),
                  logdir=".",
                  name="frozen_graph.pb",
                  as_text=False)

PS：建議使用 TensorFlow 1.14 或以上版本進行轉換。

完整代碼：

# Copyright 2020 chenli Authors. All Rights Reserved.
import tensorflow as tf
import os
from mobilenetv3_factory import build_mobilenetv3


def mobilenetv3_creat(num_classes=2, img_size=128):
    """Create an (untrained) small MobileNetV3 classifier.

    :param:
    num_classes(int): number of output classes
    img_size(int): height and width of the square, 3-channel input image
    :return:
    model: the model returned by ``build_mobilenetv3``
    """
    # Square input with 3 channels.
    input_shape = (img_size, img_size, 3)
    return build_mobilenetv3(
        "small",
        input_shape=input_shape,
        num_classes=num_classes,
        width_multiplier=1.0,
        l2_reg=1e-5,
    )


def mobilenetv3_load_weight(model, model_path="mobilenetv3_small_10.h5"):
    """Load pretrained weights into ``model`` and return it.

    :param:
    model: model object exposing ``load_weights`` (e.g. a tf.keras model)
    model_path(str): path to the weights file
    :return:
    model: the same model object that was passed in, after ``load_weights``
    """
    model.load_weights(model_path)
    return model


def export_frozen_graph(model, logdir=".", name="frozen_graph.pb", as_text=False):
    """Freeze a tf.keras model into a single GraphDef protobuf file.

    Wraps the model's forward pass in a ``tf.function``, traces it for the
    model's first input signature, inlines all variables as constants and
    writes the resulting GraphDef to ``logdir/name``.

    :param:
    model: tf.keras model (only ``model.inputs[0]`` is used for tracing)
    logdir(str): output directory
    name(str): output file name
    as_text(bool): write a text protobuf instead of a binary one
    :return:
    str: path of the written file, as returned by ``tf.io.write_graph``
    """
    # Local import: private TF module that is only needed for export.
    from tensorflow.python.framework.convert_to_constants import (
        convert_variables_to_constants_v2,
    )

    # training=False traces layers whose behaviour differs between training
    # and inference (e.g. BatchNormalization, Dropout) in inference mode,
    # independent of the global Keras learning phase.
    full_model = tf.function(lambda x: model(x, training=False))
    concrete_func = full_model.get_concrete_function(
        tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

    # Get frozen ConcreteFunction: every variable becomes a constant.
    frozen_func = convert_variables_to_constants_v2(concrete_func)
    return tf.io.write_graph(graph_or_graph_def=frozen_func.graph.as_graph_def(),
                             logdir=logdir,
                             name=name,
                             as_text=as_text)


if __name__ == "__main__":
    # `tf.enable_eager_execution` only exists in TF 1.x; on TF 2.x eager mode
    # is already on. `tf.compat.v1` is available on both, so guard the call.
    if not tf.executing_eagerly():
        tf.compat.v1.enable_eager_execution()

    classes = ['ClientFace', 'ImposterFace']
    num_classes = len(classes)
    img_size = 128
    model_path = "mobilenetv3_small_10.h5"

    model = mobilenetv3_creat(num_classes=num_classes, img_size=img_size)
    model = mobilenetv3_load_weight(model, model_path=model_path)

    export_frozen_graph(model, logdir=".", name="frozen_graph.pb", as_text=False)

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章