Converting between .ckpt and .pb

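Convert a pretrained .ckpt model to .pb (protobuf) format. The scripts on this page use the TensorFlow 1.x API (tf.Session, tf.train.Saver):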

import os
import tensorflow as tf

# Get the current directory
dir_path = os.path.dirname(os.path.realpath(__file__))
print("Current directory : ", dir_path)
save_dir = os.path.join(dir_path, '.ipynb_checkpoints')

graph = tf.get_default_graph()

# Create a session for running Ops on the Graph.
sess = tf.Session()

print("Restoring the model to the default graph ...")
saver = tf.train.import_meta_graph(dir_path + '/yolov3_coco.ckpt.meta')
saver.restore(sess,tf.train.latest_checkpoint(dir_path))
print("Restoring Done .. ")

print("Saving the model to Protobuf format: ", save_dir)

# Save the graph to protobuf (.pb and .pbtxt) files.
# Note: write_graph stores only the graph structure; the variable values stay in the checkpoint.
tf.train.write_graph(sess.graph_def, save_dir, "Binary_Protobuf.pb", False)
tf.train.write_graph(sess.graph_def, save_dir, "Text_Protobuf.pbtxt", True)
print("Saving Done .. ")
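
The .pb written by the script above contains the graph structure but not the trained weights, so it cannot run inference on its own. One common way to get a self-contained file is to freeze the graph, that is, to turn the variables into constants before writing it. The following is only a minimal sketch, not the original author's method: it assumes a TensorFlow build that ships tensorflow.compat.v1, and the tiny x/w/y graph and the file name frozen.pb are hypothetical stand-ins for the real model and its output node names.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # required on TF 2.x for Session-style code

# Hypothetical toy model standing in for the restored network.
graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(tf.float32, shape=[None, 2], name="x")
    w = tf.Variable([[1.0], [2.0]], name="w")
    y = tf.matmul(x, w, name="y")
    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        # Replace every variable reachable from the output node "y" by a constant.
        frozen_def = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph_def, ["y"]
        )

out_dir = tempfile.mkdtemp()
tf.train.write_graph(frozen_def, out_dir, "frozen.pb", as_text=False)
frozen_path = os.path.join(out_dir, "frozen.pb")

# Reload the frozen file in a fresh graph and run it without any checkpoint.
with tf.gfile.GFile(frozen_path, "rb") as f:
    loaded_def = tf.GraphDef()
    loaded_def.ParseFromString(f.read())

with tf.Graph().as_default() as loaded_graph:
    tf.import_graph_def(loaded_def, name="")
    with tf.Session(graph=loaded_graph) as sess:
        output = sess.run("y:0", feed_dict={"x:0": [[1.0, 1.0]]})

print(output)
```

For a real model, the list passed as the third argument has to name the actual output nodes, which depend on how the network was built.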

Load the protobuf file and convert it to .ckpt (checkpoint) format:

import tensorflow as tf
import argparse 

# Pass the filename as an argument
parser = argparse.ArgumentParser()
parser.add_argument("--frozen_model_filename", default="/path-to-pb-file/Binary_Protobuf.pb", type=str, help="Pb model file to import")
args = parser.parse_args()

# Load the protobuf file from disk and parse it to retrieve the
# deserialized graph_def.
with tf.gfile.GFile(args.frozen_model_filename, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    # Import the graph_def into a new graph; every node name gets the "prefix/" scope.
    tf.import_graph_def(graph_def, name="prefix")
    sess = tf.Session(graph=graph)
    # Caution: tf.train.Saver() raises ValueError ("No variables to save") when the
    # graph has no registered variables, which is the case for a graph that was
    # only imported from a GraphDef.
    saver = tf.train.Saver()
    save_path = saver.save(sess, "path-to-ckpt/model.ckpt")
    print("Model saved to ckpt format")
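
A Saver can only store variables, and a graph imported with tf.import_graph_def registers none. One possible workaround, shown here as a rough sketch and not as the original author's method, is to re-create each float constant of a frozen graph as a variable of the same name and save those. This only stores the weight values under their node names; it does not rebuild the original training graph. The sketch assumes tensorflow.compat.v1 is available, and the small x/w/y graph is a hypothetical stand-in for a frozen .pb loaded from disk.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # required on TF 2.x for Session-style code

# Hypothetical stand-in for a frozen .pb: the weights are Const nodes.
source = tf.Graph()
with source.as_default():
    x = tf.placeholder(tf.float32, shape=[None, 2], name="x")
    w = tf.constant([[1.0], [2.0]], name="w")
    tf.matmul(x, w, name="y")
graph_def = source.as_graph_def()

ckpt_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")

with tf.Graph().as_default() as graph:
    # Re-create every float32 constant as a variable so the Saver has something to store.
    for node in graph_def.node:
        is_float_const = (
            node.op == "Const"
            and node.attr["dtype"].type == tf.float32.as_datatype_enum
        )
        if is_float_const:
            value = tf.make_ndarray(node.attr["value"].tensor)
            tf.Variable(value, name=node.name)

    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().save(sess, ckpt_path)

# Read the checkpoint back to confirm the weight was stored under the node name.
reader = tf.train.load_checkpoint(ckpt_path)
restored_w = reader.get_tensor("w")
print(restored_w)
```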

The code for saving a trained model is as follows:

sess = tf.Session()
saver = tf.train.Saver()
model_path = r"D:\sample\model.ckpt"  # raw string so the backslashes are not read as escapes
save_path = saver.save(sess, model_path)
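
Windows paths written as ordinary string literals are fragile, because a backslash followed by certain letters is read as an escape sequence. The short sketch below illustrates the problem with a hypothetical path (not the one from this article) chosen so that the escapes are visible, and shows two safer spellings.

```python
import os

# Hypothetical path: "\t" and "\n" are silently turned into a tab and a newline.
plain = "D:\table\new.ckpt"
# A raw string keeps every backslash as written.
raw = r"D:\table\new.ckpt"
# Forward slashes avoid the issue entirely and are accepted by Windows APIs.
forward = "D:/table/new.ckpt"
# os.path.join builds the path with the separator of the current platform.
joined = os.path.join("D:/table", "new.ckpt")

print(repr(plain))
print(repr(raw))
```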

To use the saved model, the code is as follows:

# The graph (including x and y) must be rebuilt before the Saver is created.
sess = tf.Session()
saver = tf.train.Saver()
saver.restore(sess, r"D:\sample\model.ckpt")
result = sess.run(y, feed_dict={x: data})

y is the output tensor, and result holds its computed value.
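
The save and restore snippets above leave out the model itself, so here is a small end-to-end sketch of the round trip. It assumes tensorflow.compat.v1 is available; build_model and its x/w/y tensors are hypothetical and only stand in for a real network. The weights are changed before saving so that the restored result can be told apart from the initial values.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # required on TF 2.x for Session-style code


def build_model():
    """Hypothetical one-layer model; a real network would go here."""
    x = tf.placeholder(tf.float32, shape=[None, 2], name="x")
    w = tf.Variable([[1.0], [2.0]], name="w")
    y = tf.matmul(x, w, name="y")
    return x, w, y


model_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# "Training" phase: build the graph, change the weights, save them.
with tf.Graph().as_default():
    x, w, y = build_model()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(w.assign([[10.0], [20.0]]))  # stands in for real training
        tf.train.Saver().save(sess, model_path)

# Inference phase: rebuild the same graph, then restore instead of initializing.
with tf.Graph().as_default():
    x, w, y = build_model()
    with tf.Session() as sess:
        tf.train.Saver().restore(sess, model_path)
        result = sess.run(y, feed_dict={x: [[3.0, 4.0]]})

print(result)
```

The variable names in the rebuilt graph have to match those stored in the checkpoint, which is why both phases call the same build_model function.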

GitHub link: SymphonyPy/Valified_Code_Classify, a program that recognizes very simple CAPTCHAs
