如何使用tensor2tensor部署一個預測任務

之前有人說怎麼將t2t的訓練模型部署起來,其實不難!
首先,是安裝tensorflow-model-server 可以自行百度!
然後進行下列操作:

這裏假設你已經有了訓練好的t2t模型
模型導出:

t2t-exporter \
        --t2t_usr_dir=$T2T_USR_DIR \
        --model=$MODEL \
        --hparams_set=$HPARAMS \
        --problem=$PROBLEM \
        --data_dir=$DATA_DIR \
        --output_dir=$TRAIN_DIR

模型部署:

tensorflow_model_base_server \
        --port=9000 \
        --model_name=my_model \
        --model_base_path=$TRAIN_DIR/export/Servo

注意這裏的tensor2tensor版本是 1.7.x

然後進行模型的調用,也就是使用模型進行預測:
這裏是根據官方代碼修改而來:只需在請求處理函數中調用下面的這個函數就行!

# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from oauth2client.client import GoogleCredentials
from six.moves import input  # pylint: disable=redefined-builtin

from tensor2tensor import problems as problems_lib  # pylint: disable=unused-import
from tensor2tensor.serving import serving_utils
from tensor2tensor.utils import registry
from tensor2tensor.utils import usr_dir
import tensorflow as tf

def make_request_fn(server_name, server_address):
    """Returns a request function."""
    request_fn = serving_utils.make_grpc_request_fn(
        servable_name=server_name,
        server=server_address,
        timeout_secs=10)
    return request_fn


def query_t2t(input_txt, data_dir, problem_name, server_name, server_address, t2t_usr_dir):
    usr_dir.import_usr_dir(t2t_usr_dir)
    problem = registry.problem(problem_name)
    hparams = tf.contrib.training.HParams(
        data_dir=os.path.expanduser(data_dir))
    problem.get_hparams(hparams)
    request_fn = make_request_fn(server_name, server_address)
    inputs = input_txt
    outputs = serving_utils.predict([inputs], problem, request_fn)
    output, score = outputs
    return output, score

上面的函數參數,分別是:
輸入內容、數據文件夾,問題名稱,服務名稱,服務的地址,自定義文件夾
其中這些參數都是和上述命令中對應的,端口是默認9000,可以根據需要進行更改,歡迎留言評論討論

如果是用的docker進行部署,注意端口的映射,容器內端口映射到服務端口

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章