導出的pb模型可使用tf.contrib.predictor很方便的進行預測,僅限於tensorflow 1.x,代碼如下:
import numpy as np
import tensorflow as tf
from tensorflow.contrib import predictor
# 加載模型,使用estimator導出的模型、tf.saved_model保存的模型都可以使用該方法
# 模型目錄文件爲:saved_model.pb variables/xxx
predictor_fn = predictor.from_saved_model("./model/")
def input_fn(line):
# 定義輸入feed_dict,輸入爲device_id
feed_dict = {"device_id": [int(line)]}
return feed_dict
# 輸入預測即可
text = "567892"
feed_dict = input_fn(text)
pred = predictor_fn(feed_dict)
print(pred)