tf.keras 模型保存爲 pb 文件

折騰了我幾天,一直搞不定。最後用以下代碼成功保存。

方法一:

TensorFlow 2.0 及以上版本可以使用:

# Export `model` as a SavedModel directory named "save_test", then load it back.
tf.saved_model.save(model, export_dir="save_test")
model = tf.saved_model.load(export_dir="save_test")

來保存成 pb 文件,以及讀取。注意:這樣得到的是一個 SavedModel 目錄(save_test/ 下包含 saved_model.pb 和 variables/ 子目錄),而不是單獨一個 .pb 文件。

方法二:

TensorFlow 1.x 版本可以使用如下代碼保存(SavedModelBuilder),再在新的 Session 中加載並推理:

    # Export the current Keras session (graph + variable values) as a
    # SavedModel directory. The tag list must match the one used when loading.
    model_name = 'my_model'
    session = tf.keras.backend.get_session()
    builder = tf.saved_model.builder.SavedModelBuilder(export_dir=model_name)
    builder.add_meta_graph_and_variables(sess=session, tags=["my_model"])
    builder.save()

    # Load the SavedModel exported above into a fresh graph and classify one
    # image with it. The tag list ("my_model") must match the one passed to
    # add_meta_graph_and_variables when saving.
    model_name = 'my_model'
    with tf.Session(graph=tf.Graph()) as sess:
        # Restores the graph and the trained variable values. Do NOT run
        # tf.global_variables_initializer() afterwards: it would overwrite the
        # restored weights with fresh initial values.
        tf.saved_model.loader.load(sess, ["my_model"], model_name)

        # Print every op in the graph to help find the input/output tensor names.
        for i, m in enumerate(sess.graph.get_operations()):
            print('op{}:'.format(i), m.values())

        input_x = sess.graph.get_tensor_by_name("input_1:0")  # name taken from model.input.name
        print("input_X:", input_x)

        out_softmax = sess.graph.get_tensor_by_name(
            "MobileNetV3_Small/LastStage/Squeeze/Squeeze_1:0")  # name taken from model.output.name
        print("Output:", out_softmax)

        # Read the image. cv2.imread returns None (instead of raising) when the
        # file is missing or unreadable, so fail early with a clear message.
        img = cv2.imread("1.jpg")
        if img is None:
            raise IOError("cannot read image: 1.jpg")
        img = cv2.resize(img, (128, 128))  # dsize is (width, height)
        img = img.astype(np.float32)
        # NOTE(review): no normalization is applied here; this must match the
        # preprocessing used during training — confirm.
        print("img data type:", img.dtype)

        img_out_softmax = sess.run(out_softmax,
                                   feed_dict={input_x: np.reshape(img, (1, 128, 128, 3))})
        print("img_out_softmax:", img_out_softmax)
        for i, prob in enumerate(img_out_softmax[0]):
            print('class {} prob:{}'.format(i, prob))
        prediction_labels = np.argmax(img_out_softmax, axis=1)
        print("Final class if:", prediction_labels)
        print("prob of label:", img_out_softmax[0, prediction_labels])

方法三:

還有一種方法是把圖凍結(將變量轉為常量)後保存為單個 .pb 文件,但是有些模型讀取時會出錯:

    def freeze_graph(graph, session, output_node_names, model_name):
        """Freeze ``graph`` into a single binary GraphDef ``.pb`` file.

        Variables are converted to constants holding their current values in
        ``session``, so the written file does not need a separate checkpoint.

        Args:
            graph: the ``tf.Graph`` that holds the model.
            session: the ``tf.Session`` whose variable values are baked in.
            output_node_names: list of output op names (without the ``:0``
                suffix) that the frozen graph must be able to compute.
            model_name: output file name. Only its basename is used, and the
                file is written relative to the current working directory. A
                ``.pb`` suffix is appended unless it is already present.

        Returns:
            The path of the written file, as returned by ``tf.train.write_graph``.
        """
        import os

        with graph.as_default():
            graphdef_inf = tf.graph_util.remove_training_nodes(graph.as_graph_def())
            graphdef_frozen = tf.graph_util.convert_variables_to_constants(
                session, graphdef_inf, output_node_names)
        file_name = os.path.basename(model_name)
        # Avoid producing "name.pb.pb" when the caller already supplied the suffix.
        if not file_name.endswith(".pb"):
            file_name += ".pb"
        # tf.train.write_graph is the public export of graph_io.write_graph.
        return tf.train.write_graph(graphdef_frozen, "", file_name, as_text=False)

    # Fix Keras' learning phase to 0 (test/inference mode) before freezing.
    # NOTE(review): this is commonly advised to be called before the model is
    # built/loaded for it to take effect — confirm for this model.
    tf.keras.backend.set_learning_phase(0)  # this line most important
    # freeze_graph adds the ".pb" suffix to the file name itself, so pass the
    # bare name (passing 'my_model2.pb' could yield "my_model2.pb.pb").
    model_name = 'my_model2'
    session = tf.keras.backend.get_session()
    freeze_graph(session.graph, session, [out.op.name for out in model.outputs], model_name)
def recognize(jpg_path, pb_file_path,
              input_name="input_1:0",
              output_name="MobileNetV3_Small/LastStage/Squeeze/Squeeze_1:0",
              size=(128, 128)):
    """Classify one image with a frozen ``.pb`` graph and print the scores.

    Args:
        jpg_path: path of the image to classify.
        pb_file_path: path of a frozen binary GraphDef file.
        input_name: name of the input tensor in the frozen graph.
        output_name: name of the output tensor in the frozen graph.
        size: ``(width, height)`` the image is resized to before feeding.

    Returns:
        ``np.argmax`` of the output along axis 1 (one label per batch row).

    Raises:
        IOError: if the image cannot be read.
    """
    # Load the serialized GraphDef.
    graph_def = tf.GraphDef()
    with open(pb_file_path, "rb") as f:
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        # name="" keeps the original tensor names; a non-empty name (the
        # original used '1') prefixes every name, e.g. "1/input_1:0".
        tf.import_graph_def(graph_def, name="")

    # The session must be bound to the graph the GraphDef was imported into;
    # otherwise it runs against a different default graph without these ops.
    # A frozen graph holds constants only, so no variable initializer is needed.
    with tf.Session(graph=graph) as sess:
        # Print every op in the graph to help find the input/output tensor names.
        for i, m in enumerate(sess.graph.get_operations()):
            print('op{}:'.format(i), m.values())

        input_x = sess.graph.get_tensor_by_name(input_name)
        print("input_X:", input_x)

        out_softmax = sess.graph.get_tensor_by_name(output_name)
        print("Output:", out_softmax)

        # Read in colour: flag 0 would load a 2-D grayscale image that cannot be
        # fed as (1, H, W, 3). cv2.imread returns None on failure.
        img = cv2.imread(jpg_path)
        if img is None:
            raise IOError("cannot read image: {}".format(jpg_path))
        # cv2.resize expects dsize as a 2-tuple (width, height).
        img = cv2.resize(img, tuple(size))
        img = img.astype(np.float32)
        # NOTE(review): inverted [0, 1] scaling; it must match the preprocessing
        # used during training — confirm.
        img = 1 - img / 255
        print("img data type:", img.dtype)

        # Add the batch dimension: (H, W, 3) -> (1, H, W, 3).
        img_out_softmax = sess.run(out_softmax,
                                   feed_dict={input_x: img[np.newaxis, ...]})

        print("img_out_softmax:", img_out_softmax)
        for i, prob in enumerate(img_out_softmax[0]):
            print('class {} prob:{}'.format(i, prob))
        prediction_labels = np.argmax(img_out_softmax, axis=1)
        print("Final class if:", prediction_labels)
        print("prob of label:", img_out_softmax[0, prediction_labels])
        return prediction_labels

參考文獻:

https://zhuanlan.zhihu.com/p/55600911

https://blog.csdn.net/qq_25109263/article/details/81285952

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章