之前有人說怎麼將t2t的訓練模型部署起來,其實不難!
首先,是安裝tensorflow-model-server 可以自行百度!
然後進行下列操作:
這裏假設你已經有了訓練好的t2t模型
模型導出:
t2t-exporter \
--t2t_usr_dir=$T2T_USR_DIR \
--model=$MODEL \
--hparams_set=$HPARAMS \
--problem=$PROBLEM \
--data_dir=$DATA_DIR \
--output_dir=$TRAIN_DIR
模型部署:
tensorflow_model_base_server \
--port=9000 \
--model_name=my_model \
--model_base_path=$TRAIN_DIR/export/Servo
注意這裏的tensor2tensor版本是 1.7.x
然後進行模型的調用,也就是使用模型進行預測:
這裏是根據官方代碼修改而來:只需在請求處理函數中調用下面的這個函數就行!
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from oauth2client.client import GoogleCredentials
from six.moves import input # pylint: disable=redefined-builtin
from tensor2tensor import problems as problems_lib # pylint: disable=unused-import
from tensor2tensor.serving import serving_utils
from tensor2tensor.utils import registry
from tensor2tensor.utils import usr_dir
import tensorflow as tf
def make_request_fn(server_name, server_address):
"""Returns a request function."""
request_fn = serving_utils.make_grpc_request_fn(
servable_name=server_name,
server=server_address,
timeout_secs=10)
return request_fn
def query_t2t(input_txt, data_dir, problem_name, server_name, server_address, t2t_usr_dir):
usr_dir.import_usr_dir(t2t_usr_dir)
problem = registry.problem(problem_name)
hparams = tf.contrib.training.HParams(
data_dir=os.path.expanduser(data_dir))
problem.get_hparams(hparams)
request_fn = make_request_fn(server_name, server_address)
inputs = input_txt
outputs = serving_utils.predict([inputs], problem, request_fn)
output, score = outputs
return output, score
上面的函數參數,分別是:
輸入內容、數據文件夾,問題名稱,服務名稱,服務的地址,自定義文件夾
其中這些參數都是和上述命令中對應的,端口是默認9000,可以根據需要進行更改,歡迎留言評論討論
如果是用的docker進行部署,注意端口的映射,容器內端口映射到服務端口