折騰了我好幾天,一直搞不定,最後用以下代碼成功把 Keras 模型保存了下來。
方法一:
TensorFlow 2.0 及以上版本可以使用
# Method one (TensorFlow 2.x): export `model` to a SavedModel directory and
# load it back from that same directory.
export_dir = "save_test"
tf.saved_model.save(model, export_dir)
# NOTE: per the TF docs, tf.saved_model.load returns a generic trackable
# object, not a tf.keras.Model; use tf.keras.models.load_model if Keras
# methods such as predict() are needed on the reloaded model.
model = tf.saved_model.load(export_dir)
來保存成 SavedModel 格式(生成一個目錄,其中包含 saved_model.pb 文件和 variables 子目錄),以及讀取。
方法二:
TensorFlow 1.x 版本可以使用如下代碼保存,以及讀取後對單張圖片做預測:
# Method two (TensorFlow 1.x): export the graph and current variable values
# of the Keras backend session as a SavedModel directory.
model_name = 'my_model'  # doubles as the export directory
session = tf.keras.backend.get_session()
# The loader must request exactly this tag list to find the meta graph.
meta_graph_tags = ["my_model"]
# NOTE: per the TF 1.x docs, SavedModelBuilder refuses an export directory
# that already exists and is non-empty.
builder = tf.saved_model.builder.SavedModelBuilder(model_name)
builder.add_meta_graph_and_variables(session, meta_graph_tags)
builder.save()
# Method two, loading side: restore the SavedModel exported to 'my_model'
# under the tag list ["my_model"], then classify a single image with it.
model_name = 'my_model'
image_path = "1.jpg"
with tf.Session(graph=tf.Graph()) as sess:
    # Restores the graph and the saved variable values into `sess`.
    # Do NOT run tf.global_variables_initializer() afterwards: it would reset
    # the restored variables to their initial values.
    tf.saved_model.loader.load(sess, ["my_model"], model_name)

    # Print every operation in the graph to help find the tensor names.
    op = sess.graph.get_operations()
    for i, m in enumerate(op):
        print('op{}:'.format(i), m.values())

    # Tensor names depend on the model; compare with the model's input.name.
    input_x = sess.graph.get_tensor_by_name("input_1:0")
    print("input_X:", input_x)
    # Compare with the model's output.name.
    out_softmax = sess.graph.get_tensor_by_name(
        "MobileNetV3_Small/LastStage/Squeeze/Squeeze_1:0")
    print("Output:", out_softmax)

    # Read the image (OpenCV loads color images in BGR order).
    # cv2.imread signals failure by returning None rather than raising.
    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError("could not read image: {}".format(image_path))
    img = cv2.resize(img, (128, 128))
    img = img.astype(np.float32)
    # NOTE(review): no normalization is applied here; confirm that raw
    # 0-255 BGR input matches the model's training preprocessing.
    print("img data type:", img.dtype)

    # Add the batch dimension: (128, 128, 3) -> (1, 128, 128, 3).
    img_out_softmax = sess.run(out_softmax,
                               feed_dict={input_x: np.reshape(img, (1, 128, 128, 3))})
    print("img_out_softmax:", img_out_softmax)
    for i, prob in enumerate(img_out_softmax[0]):
        print('class {} prob:{}'.format(i, prob))
    prediction_labels = np.argmax(img_out_softmax, axis=1)
    print("Final class if:", prediction_labels)
    print("prob of label:", img_out_softmax[0, prediction_labels])
方法三:
還有一種方法是先凍結計算圖(把變量轉成常量),再保存成單個 .pb 文件,但是有些模型讀取時會出錯:
def frozen_pb_filename(model_name):
    """Return the file name that freeze_graph writes for *model_name*.

    Directory components are dropped, and ".pb" is appended only when the
    name does not already end with it. For example, "my_model2.pb" stays
    "my_model2.pb" instead of becoming "my_model2.pb.pb".
    """
    base = os.path.basename(model_name)
    return base if base.endswith(".pb") else base + ".pb"


def freeze_graph(graph, session, output_node_names, model_name):
    """Freeze *graph* (variables -> constants) and write it as a binary .pb.

    Args:
        graph: The tf.Graph holding the model.
        session: Session whose current variable values are baked into the
            frozen graph as constants.
        output_node_names: List of output op names (e.g. ``out.op.name``),
            which determine the subgraph that is kept.
        model_name: Output file name. Directory parts are dropped, so the
            file is written relative to the current working directory;
            ".pb" is appended unless already present.

    Returns:
        The path returned by ``graph_io.write_graph``.
    """
    with graph.as_default():
        # Drop nodes that are only needed for training before freezing.
        graphdef_inf = tf.graph_util.remove_training_nodes(graph.as_graph_def())
        graphdef_frozen = tf.graph_util.convert_variables_to_constants(
            session, graphdef_inf, output_node_names)
        return graph_io.write_graph(
            graphdef_frozen, "", frozen_pb_filename(model_name), as_text=False)
# Switch Keras to the test/inference learning phase (0 = test, 1 = train)
# before freezing.
# NOTE(review): `model` already exists at this point; setting the learning
# phase may need to happen before the model is built/loaded to affect its
# graph — confirm.
tf.keras.backend.set_learning_phase(0)  # this line most important
model_name = 'my_model2.pb'
session = tf.keras.backend.get_session()
# freeze_graph expects op names (no ":0" suffix), one per model output.
output_node_names = [tensor.op.name for tensor in model.outputs]
freeze_graph(session.graph, session, output_node_names, model_name)
def recognize(jpg_path, pb_file_path,
              input_tensor_name="input_1:0",
              output_tensor_name="MobileNetV3_Small/LastStage/Squeeze/Squeeze_1:0",
              input_size=(128, 128)):
    """Classify one image with a frozen (.pb) graph and print the results.

    Args:
        jpg_path: Path of the image to classify; it is read with OpenCV as
            a 3-channel (BGR) image.
        pb_file_path: Path of a binary frozen GraphDef, e.g. one written by
            ``freeze_graph``.
        input_tensor_name: Name of the graph's input tensor (compare with
            the model's ``input.name`` or the printed op list).
        output_tensor_name: Name of the graph's output tensor (compare with
            the model's ``output.name``).
        input_size: ``(width, height)`` the image is resized to.

    Returns:
        The array of predicted class indices from ``np.argmax`` over the
        output, one entry per batch item (a single entry here).

    Raises:
        FileNotFoundError: If OpenCV cannot read ``jpg_path``.
    """
    # Load the serialized frozen graph.
    output_graph_def = tf.GraphDef()
    with open(pb_file_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())

    graph = tf.Graph()
    with graph.as_default():
        # Import with name="" so nodes keep their original names; a
        # non-empty name such as '1' would prefix every node with "1/" and
        # the tensor lookups below would fail.
        tf.import_graph_def(output_graph_def, name="")

    # Weights in a frozen graph are constants, so no variable initializer
    # is needed before running it.
    with tf.Session(graph=graph) as sess:
        # Print every op in the graph to help find the tensor names.
        for i, m in enumerate(sess.graph.get_operations()):
            print('op{}:'.format(i), m.values())
        input_x = sess.graph.get_tensor_by_name(input_tensor_name)
        print("input_X:", input_x)
        out_softmax = sess.graph.get_tensor_by_name(output_tensor_name)
        print("Output:", out_softmax)

        # Read a 3-channel image: a grayscale read (flag 0) yields (H, W),
        # which cannot feed a (1, H, W, 3) input.
        img = cv2.imread(jpg_path, cv2.IMREAD_COLOR)
        if img is None:  # imread signals failure by returning None
            raise FileNotFoundError("could not read image: {}".format(jpg_path))
        # cv2.resize takes dsize=(width, height) — exactly two values.
        img = cv2.resize(img, tuple(input_size))
        img = img.astype(np.float32)
        # NOTE(review): scales to [0, 1] and inverts intensities; confirm
        # this matches the model's training preprocessing.
        img = 1 - img / 255
        print("img data type:", img.dtype)

        # Add the batch dimension: (H, W, 3) -> (1, H, W, 3).
        img_out_softmax = sess.run(out_softmax,
                                   feed_dict={input_x: np.expand_dims(img, axis=0)})
        print("img_out_softmax:", img_out_softmax)
        for i, prob in enumerate(img_out_softmax[0]):
            print('class {} prob:{}'.format(i, prob))
        prediction_labels = np.argmax(img_out_softmax, axis=1)
        print("Final class if:", prediction_labels)
        print("prob of label:", img_out_softmax[0, prediction_labels])
    return prediction_labels
參考文獻:
https://zhuanlan.zhihu.com/p/55600911
https://blog.csdn.net/qq_25109263/article/details/81285952