背景
使用BERT-TensorFlow解決法研杯要素識別任務
,該任務其實是一個多標籤文本分類任務。模型的具體不是本文重點,故於此不細細展開說明。本文重點闡述如何部署模型。
模型部署
官方推薦TensorFlow模型在生產環境中提供服務時使用SavedModel格式。SavedModel格式是一種通用的、語言中立的、密閉的、可恢復的TensorFlow模型序列化格式。SavedModel封裝了TensorFlow Saver,對於模型服務是一種標準的導出方法。
導出SaveModel格式
這裏的estimator
部分也忽略,不詳細說明,其關鍵是調用estimator的export_savedmodel
導出SaveModel格式的模型,注意serving_input_fn
的編寫。其中的字段與後續POST中的數據字段相對應。
def serving_input_fn():
# 保存模型爲SaveModel格式
# 採用最原始的feature方式,輸入是feature Tensors。
# 如果採用build_parsing_serving_input_receiver_fn,則輸入是tf.Examples
label_ids = tf.placeholder(tf.int32, [None, 20], name='label_ids') # 要素識別任務有20個類別
input_ids = tf.placeholder(tf.int32, [None, cfig.max_seq_length], name='input_ids')
input_mask = tf.placeholder(tf.int32, [None, cfig.max_seq_length], name='input_mask')
segment_ids = tf.placeholder(tf.int32, [None, cfig.max_seq_length], name='segment_ids')
input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
'label_ids': label_ids,
'input_ids': input_ids,
'input_mask': input_mask,
'segment_ids': segment_ids,
})()
return input_fn
if cfig.do_export:
estimator._export_to_tpu = False
estimator.export_savedmodel(cfig.export_dir, serving_input_fn)
生成的SaveModel:
檢查模型:
saved_model_cli show --dir save_model/1 --all
結果如下圖所示:
部署服務
先基於Docker拉取tensorflow/serving
鏡像(PS:這是CPU版)。再基於鏡像,啓動容器:
docker run --rm -t -p 8501:8501 -v /home/liujiepeng/MachineComprehension/CAIL2019/ElementsRecognition/bert_tensorflow_multi_label/save_model:/models/cail_elem --name=tfserving_cail -e MODEL_NAME=cail_elem tensorflow/serving:latest
運行結果:
2019-09-21 03:24:48.782137: I tensorflow_serving/model_servers/server.cc:82] Building single TensorFlow model file config: model_name: cail_elem model_base_path: /models/cail_elem
2019-09-21 03:24:48.782580: I tensorflow_serving/model_servers/server_core.cc:462] Adding/updating models.
2019-09-21 03:24:48.782633: I tensorflow_serving/model_servers/server_core.cc:561] (Re-)adding model: cail_elem
2019-09-21 03:24:48.883257: I tensorflow_serving/core/basic_manager.cc:739] Successfully reserved resources to load servable {name: cail_elem version: 1}
2019-09-21 03:24:48.883351: I tensorflow_serving/core/loader_harness.cc:66] Approving load for servable version {name: cail_elem version: 1}
2019-09-21 03:24:48.883433: I tensorflow_serving/core/loader_harness.cc:74] Loading servable version {name: cail_elem version: 1}
2019-09-21 03:24:48.883530: I external/org_tensorflow/tensorflow/contrib/session_bundle/bundle_shim.cc:363] Attempting to load native SavedModelBundle in bundle-shim from: /models/cail_elem/1
2019-09-21 03:24:48.883581: I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:31] Reading SavedModel from: /models/cail_elem/1
2019-09-21 03:24:48.917199: I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:54] Reading meta graph with tags { serve }
2019-09-21 03:24:48.948563: I external/org_tensorflow/tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 AVX512F FMA
2019-09-21 03:24:49.028645: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:202] Restoring SavedModel bundle.
2019-09-21 03:24:49.497106: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:151] Running initialization op on SavedModel bundle at path: /models/cail_elem/1
2019-09-21 03:24:49.543113: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:311] SavedModel load for tags { serve }; Status: success. Took 659522 microseconds.
2019-09-21 03:24:49.543191: I tensorflow_serving/servables/tensorflow/saved_model_warmup.cc:103] No warmup data file found at /models/cail_elem/1/assets.extra/tf_serving_warmup_requests
2019-09-21 03:24:49.543323: I tensorflow_serving/core/loader_harness.cc:86] Successfully loaded servable version {name: cail_elem version: 1}
2019-09-21 03:24:49.549907: I tensorflow_serving/model_servers/server.cc:324] Running gRPC ModelServer at 0.0.0.0:8500 ...
[warn] getaddrinfo: address family for nodename not supported
[evhttp_server.cc : 239] RAW: Entering the event loop ...
2019-09-21 03:24:49.557068: I tensorflow_serving/model_servers/server.cc:344] Exporting HTTP/REST API at:localhost:8501 ...
查看正在運行的容器:docker container ls
請求服務
對原始請求進行封裝,構建符合要求的POST請求:
# -*- coding: utf-8 -*-
# @CreatTime : 2019/9/20 11:46
# @Author : JasonLiu
# @FileName: test_tfserving.py
import requests
import json
import tensorflow as tf
import collections
import pdb
import numpy as np
from bert import tokenization
from utils import create_examples_text_list, convert_single_example
def test_request():
label_ids = 20*[0]
input_ids = 512*[1]
input_mask = 512*[1]
segment_ids = 512*[1]
data_dict_temp = {
'label_ids': label_ids,
'input_ids': input_ids,
'input_mask': input_mask,
'segment_ids': segment_ids,
}
data_list = []
data_list.append(data_dict_temp)
data = json.dumps({"signature_name": "serving_default", "instances": data_list})
headers = {"content-type": "application/json"}
json_response = requests.post('http://localhost:8501/v1/models/cail_elem:predict', data=data, headers=headers)
print(json_response.text)
predictions = json.loads(json_response.text)['predictions']
print(predictions)
def request_from_raw_text():
"""
:return:
"""
BERT_VOCAB = "/home/data1/ftpdata/pretrain_models/bert_tensoflow_version/bert-base-chinese-vocab.txt"
text_list = ["權人宏偉支行及寶成公司共22次向怡天公司催收借款全部本金及利息,均產生訴訟時效中斷的法律效力,本案債權未過訴訟時效期間", # LN8
"2012年11月30日,原債權人工行錦州市分行向保證人錦州鍋爐有限責任公司發出督促履行保證責任通知書,要求其履行保證責任,"
"2004年11月18日,原債權人工行錦州市分行採用國內掛號信函的方式向保證人錦州鍋爐有限責任公司郵寄送達中國工商銀行遼寧省分行督促履行保證責任通知書," # LN4
"錦州市淩河區公證處相關公證人員對此過程進行了公證。",
"權人宏偉支行及寶成公司共22次向怡天公司催收借款全部本金及利息,均產生訴訟時效中斷的法律效力,本案債權未過訴訟時效期間", # LN8
"2012年11月30日,原債權人工行錦州市分行向保證人錦州鍋爐有限責任公司發出督促履行保證責任通知書,要求其履行保證責任,"
"2004年11月18日,原債權人工行錦州市分行採用國內掛號信函的方式向保證人錦州鍋爐有限責任公司郵寄送達中國工商銀行遼寧省分行督促履行保證責任通知書," # LN4
"錦州市淩河區公證處相關公證人員對此過程進行了公證。",
"權人宏偉支行及寶成公司共22次向怡天公司催收借款全部本金及利息,均產生訴訟時效中斷的法律效力,本案債權未過訴訟時效期間", # LN8
"2012年11月30日,原債權人工行錦州市分行向保證人錦州鍋爐有限責任公司發出督促履行保證責任通知書,要求其履行保證責任,"
"2004年11月18日,原債權人工行錦州市分行採用國內掛號信函的方式向保證人錦州鍋爐有限責任公司郵寄送達中國工商銀行遼寧省分行督促履行保證責任通知書," # LN4
"錦州市淩河區公證處相關公證人員對此過程進行了公證。",
"權人宏偉支行及寶成公司共22次向怡天公司催收借款全部本金及利息,均產生訴訟時效中斷的法律效力,本案債權未過訴訟時效期間", # LN8
"2012年11月30日,原債權人工行錦州市分行向保證人錦州鍋爐有限責任公司發出督促履行保證責任通知書,要求其履行保證責任,"
"2004年11月18日,原債權人工行錦州市分行採用國內掛號信函的方式向保證人錦州鍋爐有限責任公司郵寄送達中國工商銀行遼寧省分行督促履行保證責任通知書," # LN4
"錦州市淩河區公證處相關公證人員對此過程進行了公證。",
"權人宏偉支行及寶成公司共22次向怡天公司催收借款全部本金及利息,均產生訴訟時效中斷的法律效力,本案債權未過訴訟時效期間", # LN8
"2012年11月30日,原債權人工行錦州市分行向保證人錦州鍋爐有限責任公司發出督促履行保證責任通知書,要求其履行保證責任,"
"2004年11月18日,原債權人工行錦州市分行採用國內掛號信函的方式向保證人錦州鍋爐有限責任公司郵寄送達中國工商銀行遼寧省分行督促履行保證責任通知書," # LN4
"錦州市淩河區公證處相關公證人員對此過程進行了公證。",
"權人宏偉支行及寶成公司共22次向怡天公司催收借款全部本金及利息,均產生訴訟時效中斷的法律效力,本案債權未過訴訟時效期間", # LN8
"2012年11月30日,原債權人工行錦州市分行向保證人錦州鍋爐有限責任公司發出督促履行保證責任通知書,要求其履行保證責任,"
"2004年11月18日,原債權人工行錦州市分行採用國內掛號信函的方式向保證人錦州鍋爐有限責任公司郵寄送達中國工商銀行遼寧省分行督促履行保證責任通知書," # LN4
"錦州市淩河區公證處相關公證人員對此過程進行了公證。",
"權人宏偉支行及寶成公司共22次向怡天公司催收借款全部本金及利息,均產生訴訟時效中斷的法律效力,本案債權未過訴訟時效期間", # LN8
"2012年11月30日,原債權人工行錦州市分行向保證人錦州鍋爐有限責任公司發出督促履行保證責任通知書,要求其履行保證責任,"
"2004年11月18日,原債權人工行錦州市分行採用國內掛號信函的方式向保證人錦州鍋爐有限責任公司郵寄送達中國工商銀行遼寧省分行督促履行保證責任通知書," # LN4
"錦州市淩河區公證處相關公證人員對此過程進行了公證。",
"權人宏偉支行及寶成公司共22次向怡天公司催收借款全部本金及利息,均產生訴訟時效中斷的法律效力,本案債權未過訴訟時效期間", # LN8
"2012年11月30日,原債權人工行錦州市分行向保證人錦州鍋爐有限責任公司發出督促履行保證責任通知書,要求其履行保證責任,"
"2004年11月18日,原債權人工行錦州市分行採用國內掛號信函的方式向保證人錦州鍋爐有限責任公司郵寄送達中國工商銀行遼寧省分行督促履行保證責任通知書," # LN4
"錦州市淩河區公證處相關公證人員對此過程進行了公證。",
"權人宏偉支行及寶成公司共22次向怡天公司催收借款全部本金及利息,均產生訴訟時效中斷的法律效力,本案債權未過訴訟時效期間", # LN8
"2012年11月30日,原債權人工行錦州市分行向保證人錦州鍋爐有限責任公司發出督促履行保證責任通知書,要求其履行保證責任,"
"2004年11月18日,原債權人工行錦州市分行採用國內掛號信函的方式向保證人錦州鍋爐有限責任公司郵寄送達中國工商銀行遼寧省分行督促履行保證責任通知書," # LN4
"錦州市淩河區公證處相關公證人員對此過程進行了公證。",
"權人宏偉支行及寶成公司共22次向怡天公司催收借款全部本金及利息,均產生訴訟時效中斷的法律效力,本案債權未過訴訟時效期間", # LN8
"2012年11月30日,原債權人工行錦州市分行向保證人錦州鍋爐有限責任公司發出督促履行保證責任通知書,要求其履行保證責任,"
"2004年11月18日,原債權人工行錦州市分行採用國內掛號信函的方式向保證人錦州鍋爐有限責任公司郵寄送達中國工商銀行遼寧省分行督促履行保證責任通知書," # LN4
"錦州市淩河區公證處相關公證人員對此過程進行了公證。",
"權人宏偉支行及寶成公司共22次向怡天公司催收借款全部本金及利息,均產生訴訟時效中斷的法律效力,本案債權未過訴訟時效期間", # LN8
"2012年11月30日,原債權人工行錦州市分行向保證人錦州鍋爐有限責任公司發出督促履行保證責任通知書,要求其履行保證責任,"
"2004年11月18日,原債權人工行錦州市分行採用國內掛號信函的方式向保證人錦州鍋爐有限責任公司郵寄送達中國工商銀行遼寧省分行督促履行保證責任通知書," # LN4
"錦州市淩河區公證處相關公證人員對此過程進行了公證。",
"權人宏偉支行及寶成公司共22次向怡天公司催收借款全部本金及利息,均產生訴訟時效中斷的法律效力,本案債權未過訴訟時效期間", # LN8
"2012年11月30日,原債權人工行錦州市分行向保證人錦州鍋爐有限責任公司發出督促履行保證責任通知書,要求其履行保證責任,"
"2004年11月18日,原債權人工行錦州市分行採用國內掛號信函的方式向保證人錦州鍋爐有限責任公司郵寄送達中國工商銀行遼寧省分行督促履行保證責任通知書," # LN4
"錦州市淩河區公證處相關公證人員對此過程進行了公證。",
"權人宏偉支行及寶成公司共22次向怡天公司催收借款全部本金及利息,均產生訴訟時效中斷的法律效力,本案債權未過訴訟時效期間", # LN8
"2012年11月30日,原債權人工行錦州市分行向保證人錦州鍋爐有限責任公司發出督促履行保證責任通知書,要求其履行保證責任,"
"2004年11月18日,原債權人工行錦州市分行採用國內掛號信函的方式向保證人錦州鍋爐有限責任公司郵寄送達中國工商銀行遼寧省分行督促履行保證責任通知書," # LN4
"錦州市淩河區公證處相關公證人員對此過程進行了公證。",
"權人宏偉支行及寶成公司共22次向怡天公司催收借款全部本金及利息,均產生訴訟時效中斷的法律效力,本案債權未過訴訟時效期間", # LN8
"2012年11月30日,原債權人工行錦州市分行向保證人錦州鍋爐有限責任公司發出督促履行保證責任通知書,要求其履行保證責任,"
"2004年11月18日,原債權人工行錦州市分行採用國內掛號信函的方式向保證人錦州鍋爐有限責任公司郵寄送達中國工商銀行遼寧省分行督促履行保證責任通知書," # LN4
"錦州市淩河區公證處相關公證人員對此過程進行了公證。",
"權人宏偉支行及寶成公司共22次向怡天公司催收借款全部本金及利息,均產生訴訟時效中斷的法律效力,本案債權未過訴訟時效期間", # LN8
"2012年11月30日,原債權人工行錦州市分行向保證人錦州鍋爐有限責任公司發出督促履行保證責任通知書,要求其履行保證責任,"
"2004年11月18日,原債權人工行錦州市分行採用國內掛號信函的方式向保證人錦州鍋爐有限責任公司郵寄送達中國工商銀行遼寧省分行督促履行保證責任通知書," # LN4
"錦州市淩河區公證處相關公證人員對此過程進行了公證。",
"權人宏偉支行及寶成公司共22次向怡天公司催收借款全部本金及利息,均產生訴訟時效中斷的法律效力,本案債權未過訴訟時效期間", # LN8
"2012年11月30日,原債權人工行錦州市分行向保證人錦州鍋爐有限責任公司發出督促履行保證責任通知書,要求其履行保證責任,"
"2004年11月18日,原債權人工行錦州市分行採用國內掛號信函的方式向保證人錦州鍋爐有限責任公司郵寄送達中國工商銀行遼寧省分行督促履行保證責任通知書," # LN4
"錦州市淩河區公證處相關公證人員對此過程進行了公證。"
]
data_list = []
tokenizer = tokenization.FullTokenizer(vocab_file=BERT_VOCAB, do_lower_case=True)
predict_examples = create_examples_text_list(text_list)
for (ex_index, example) in enumerate(predict_examples):
feature = convert_single_example(ex_index, example,
512, tokenizer)
def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return f
features = {}
features["input_ids"] = feature.input_ids
features["input_mask"] = feature.input_mask
# pdb.set_trace()
features["segment_ids"] = feature.segment_ids
if isinstance(feature.label_ids, list):
label_ids = feature.label_ids
else:
label_ids = feature.label_ids[0]
features["label_ids"] = label_ids
# tf_example = tf.train.Example(features=tf.train.Features(feature=features))
data_list.append(features)
data = json.dumps({"signature_name": "serving_default", "instances": data_list})
headers = {"content-type": "application/json"}
json_response = requests.post('http://localhost:8501/v1/models/cail_elem:predict', data=data, headers=headers)
# print(json_response.text)
# pdb.set_trace()
predictions = json.loads(json_response.text)['predictions']
# print(predictions)
for p in range(len(predictions)):
p_list = predictions[p]
label_index = np.argmax(p_list)
print("content={},label={}".format(text_list[p], label_index+1))
print("total number=", len(text_list))
request_from_raw_text()
從運行效率來看,CPU推理上,整體偏慢。運行上述32條任務,耗時:
real 0m17.366s
user 0m1.815s
sys 0m0.997s
那麼我們試試採用tensorflow/serving:latest:gpu
版。此時,我們需要特別注意的是,本地NVIDIA 顯卡驅動和ensorflow/serving:gpu
版本的匹配問題。
由於機器cuda版本是9.0,而tensorflow/serving:latest-gpu
是對應cuda 10版本。所以,需要從https://hub.docker.com/r/tensorflow/serving/tags/
找到合適的gpu版本。最終發現tensorflow/serving:1.12.3-gpu是可以與機器適配的。所以,拉取該鏡像:docker pull tensorflow/serving:1.12.3-gpu
運行容器:
nvidia-docker run -t --rm -p 8501:8501 -v /home/liujiepeng/MachineComprehension/CAIL2019/ElementsRecognition/bert_tensorflow_multi_label/save_model:/models/cail_elem -e MODEL_NAME=cail_elem tensorflow/serving:1.12.3-gpu
即可GPU方式啓動服務。
再測試,發現運行32條任務的耗時如下:
real 0m5.574s
user 0m2.084s
sys 0m0.902s
提速明顯。