导出的pb模型可使用tf.contrib.predictor很方便的进行预测,仅限于tensorflow 1.x,代码如下:
import numpy as np
import tensorflow as tf
from tensorflow.contrib import predictor
# 加载模型,使用estimator导出的模型、tf.saved_model保存的模型都可以使用该方法
# 模型目录文件为:saved_model.pb variables/xxx
predictor_fn = predictor.from_saved_model("./model/")
def input_fn(line):
# 定义输入feed_dict,输入为device_id
feed_dict = {"device_id": [int(line)]}
return feed_dict
# 输入预测即可
text = "567892"
feed_dict = input_fn(text)
pred = predictor_fn(feed_dict)
print(pred)