模型應用
模型部署應用首選 TensorFlow，而 TensorFlow 模型部署使用 pb（凍結圖，frozen graph）格式最爲簡單。
本文以圖像分類模型爲例，介紹 pb 模型的使用方法：
- 代碼：
import cv2
import tensorflow as tf
import numpy as np
import sys, os
class Recognizer:
    """Image classifier that runs a frozen TensorFlow 1.x graph (``.pb`` file).

    The graph is imported into its own ``tf.Graph`` and a dedicated
    ``tf.Session`` is kept open, so repeated ``predict`` calls do not reload
    the model. Call ``close()`` (or use the instance in a ``with`` block) to
    release the session.

    Args:
        pb_path: Path to the serialized ``GraphDef`` protobuf.
        input_name: Name of the input tensor as defined when the network was
            built. Defaults to ``"input_1:0"``.
        output_name: Name of the output tensor as defined when the network
            was built. Defaults to ``"output_1:0"``.
    """

    def __init__(self, pb_path, input_name="input_1:0", output_name="output_1:0"):
        self.pb_path = pb_path
        self.input_name = input_name
        self.output_name = output_name
        self.config = tf.ConfigProto()
        # Allocate GPU memory on demand instead of reserving it all up front.
        self.config.gpu_options.allow_growth = True
        self.init_model()

    def init_model(self):
        """Load the frozen graph from ``self.pb_path`` and open a session on it.

        Sets ``self.output_graph_def``, ``self.graph``, ``self.sess``,
        ``self.input`` and ``self.output``.
        """
        graph_def = tf.GraphDef()
        with open(self.pb_path, 'rb') as f:
            graph_def.ParseFromString(f.read())
        self.output_graph_def = graph_def

        # Graph.as_default() only takes effect inside a ``with`` block; use a
        # private graph so several Recognizers can coexist in one process.
        self.graph = tf.Graph()
        with self.graph.as_default():
            # name="" keeps the original node names. With the default
            # (None) every node would be prefixed with "import/", and the
            # tensor lookups below would fail.
            tf.import_graph_def(graph_def, name="")
        self.sess = tf.Session(graph=self.graph, config=self.config)
        self.input = self.graph.get_tensor_by_name(self.input_name)
        self.output = self.graph.get_tensor_by_name(self.output_name)

    def predict(self, img):
        """Classify one image and return the arg-max output index as a string.

        Args:
            img: Image array indexed as ``img[h, w, c]`` (e.g. the result of
                ``cv2.imread``). Apart from the scaling below it is fed to
                the network unchanged; presumably it must already match the
                model's expected input size — TODO confirm for your model.

        Returns:
            The index of the highest-scoring output, as a ``str``.
        """
        # Map pixel values from [0, 255] to [-0.5, 0.5].
        img = (img - 255 / 2.0) / 255
        # Add a leading batch axis of size 1.
        img = img[np.newaxis, :, :, :]
        res = self.sess.run(self.output, feed_dict={self.input: img})
        return str(np.argmax(res))

    def batch_predict(self, img_list):
        """Classify each image in ``img_list`` in turn.

        Returns:
            A list of class-index strings, in the same order as ``img_list``.
        """
        return [self.predict(img) for img in img_list]

    def close(self):
        """Close the underlying TensorFlow session; safe to call repeatedly."""
        sess = getattr(self, "sess", None)
        if sess is not None:
            sess.close()
            self.sess = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
if __name__ == '__main__':
    # Usage: python <script> <model.pb> <image>
    if len(sys.argv) != 3:
        sys.exit('usage: %s <model.pb> <image>' % sys.argv[0])
    img = cv2.imread(sys.argv[2])
    if img is None:
        # cv2.imread reports an unreadable file by returning None, not raising.
        sys.exit('error: cannot read image %r' % sys.argv[2])
    recognizer = Recognizer(pb_path=sys.argv[1])
    try:
        res = recognizer.predict(img)
    finally:
        # Release the TensorFlow session even if inference fails.
        recognizer.sess.close()
    print('result:', res)
注意：
1、input、output 的 tensor 名稱是網絡中自己定義的名稱，未定義則默認爲 input_1:0、output_1:0；若不確定，可遍歷 graph_def.node 打印各節點名稱進行確認。