小白Bert系列-生成pb模型,tfserving加载,flask进行预测

bert分类模型使用tfserving部署。

bert模型服务化现在已经有对应开源库部署。

例如:1.https://github.com/macanv/BERT-BiLSTM-CRF-NER 该项目支持三种不同的任务

2.使用已有的包
pip install bert-serving-server # 服务端
pip install bert-serving-client # 客户端,与服务端互相独立

本文主要记录通过bert预训练的模型如何使用tfserving进行加载,这样在已有的tfserving model_config服务下增加一个config就可以了。

1.bert 分类模型见这篇文章。

模型实践(二)bert

2.模型输出文件

训练过程种需要注意在该文件run_classifier.py额外需要添加

def convert_single_example(ex_index, example, label_list, max_seq_length,
                           tokenizer):
  """Converts a single `InputExample` into a single `InputFeatures`."""

  if isinstance(example, PaddingInputExample):
    return InputFeatures(
        input_ids=[0] * max_seq_length,
        input_mask=[0] * max_seq_length,
        segment_ids=[0] * max_seq_length,
        label_id=0,
        is_real_example=False)

  label_map = {}
  for (i, label) in enumerate(label_list):
    label_map[label] = i

  #!!!!额外添加 保存标签信息和枚举的映射关系 并写入label2id.pkl 该文件会一起输出到output文件夹下
  output_label2id_file = os.path.join(FLAGS.output_dir, "label2id.pkl")
  if not os.path.exists(output_label2id_file):
    with open(output_label2id_file,'wb') as w:
      pickle.dump(label_map,w)

label2id.pkl其实就是一个映射关系(自定义label和模型使用label),也可以自己序列化写入:

# pickle.load(open("label2id.pkl", 'rb'))
{'1': 0, '2': 1, '3': 2}

实际训练是使用这样的命令:

export DATA_DIR=/data/bert/data
export BERT_BASE_DIR=/data/bert/model/chinese_L-12_H-768_A-12
export OUTPUT_DIR=/data/bert/

python run_classifier.py --task_name=news --do_train=true --do_eval=true --data_dir=$DATA_DIR/ --vocab_file=$BERT_BASE_DIR/vocab.txt --bert_config_file=$BERT_BASE_DIR/bert_config.json --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt --max_seq_length=128 --train_batch_size=32 --learning_rate=2e-5 --num_train_epochs=3.0 --output_dir=$OUTPUT_DIR/output

训练完成后output文件夹大致如下:

drwxr-xr-x 3 root root       4096 May 30 15:58 ./
drwxr-xr-x 7 root root       4096 May 30 16:08 ../
-rw-r--r-- 1 root root        222 May 26 21:58 checkpoint
drwxr-xr-x 2 root root       4096 May 26 21:59 eval/
-rw-r--r-- 1 root root         86 May 26 21:59 eval_results.txt
-rw-r--r-- 1 root root    1687178 May 26 21:58 eval.tf_record
-rw-r--r-- 1 root root   14356236 May 26 21:58 events.out.tfevents.1590500393.instance-py09166k
-rw-r--r-- 1 root root    9528376 May 26 21:40 graph.pbtxt
-rw-r--r-- 1 root root         38 May 26 21:39 label2id.pkl
-rw-r--r-- 1 root root 1227239468 May 26 21:40 model.ckpt-0.data-00000-of-00001
-rw-r--r-- 1 root root      22717 May 26 21:40 model.ckpt-0.index
-rw-r--r-- 1 root root    4046471 May 26 21:40 model.ckpt-0.meta
-rw-r--r-- 1 root root 1227239468 May 26 21:48 model.ckpt-1000.data-00000-of-00001
-rw-r--r-- 1 root root      22717 May 26 21:48 model.ckpt-1000.index
-rw-r--r-- 1 root root    4046471 May 26 21:48 model.ckpt-1000.meta
-rw-r--r-- 1 root root 1227239468 May 26 21:56 model.ckpt-2000.data-00000-of-00001
-rw-r--r-- 1 root root      22717 May 26 21:56 model.ckpt-2000.index
-rw-r--r-- 1 root root    4046471 May 26 21:56 model.ckpt-2000.meta
-rw-r--r-- 1 root root 1227239468 May 26 21:58 model.ckpt-2250.data-00000-of-00001
-rw-r--r-- 1 root root      22717 May 26 21:58 model.ckpt-2250.index
-rw-r--r-- 1 root root    4046471 May 26 21:58 model.ckpt-2250.meta
-rw-r--r-- 1 root root    1615474 May 26 22:28 predict.tf_record
-rw-r--r-- 1 root root     105910 May 26 22:29 test_results.tsv
-rw-r--r-- 1 root root   13543303 May 26 21:39 train.tf_record

来大致看一下几个文件作用:

checkpoint 记录可用的模型信息
eval_results.txt 验证集的结果信息
eval.tf_record 记录验证集的二进制信息
events.out.tfevents.1590500393.instance-py09166k 用于tensorboard查看详细信息
graph.pbtxt 记录tensorflow的结构信息
label2id.pkl 标签信息 (额外加的)
model.ckpt-0* 这里是记录最近的三个文件
model.ckpt-2250.data 所有变量的值
model.ckpt-2250.index 可能是用于映射图和权重关系,0.11版本后引入
model.ckpt-2250.meta 记录完整的计算图结构
predict.tf_record 预测的二进制文件
test_results.tsv 使用预测后生成的预测结果
3.生成pb模型

生成pb模型借助一个中间的文件,创建 freeze_graph.py 放置于bert文件夹下。完整版本见github:https://github.com/hiroLinGoing/bert

运行:

export DATA_DIR=/data/bert/output
export BERT_BASE_DIR=/data/bert/model/chinese_L-12_H-768_A-12

#bertPB 表示模型名。
#pb_version 模型版本号

python freeze_graph.py -bert_model_dir=$BERT_BASE_DIR -model_dir=$DATA_DIR -max_seq_len=128 -num_labels=3 -pb_file_name bertPB -pb_version 100001

运行后会在output文件夹下生成模型文件,如下:

在这里插入图片描述

直接将bertPB文件夹拷贝到tfserving所配置得文件夹下。

4.运行服务

此处包括了多个模型,model_config:

model_config_list:{
    config:{
      name:"cnn",
      base_path:"/models/text_models/cnn",
      model_platform:"tensorflow"
    },
    config:{
      name:"rnn",
      base_path:"/models/text_models/rnn",
      model_platform:"tensorflow"
    },
    config:{
      name:"lstm",
      base_path:"/models/text_models/lstm",
      model_platform:"tensorflow"
    },
    config:{
      name:"bert",
      base_path:"/models/text_models/bertPB",
      model_platform:"tensorflow"
    }
}

运行tfserving docker

docker run -p 8501:8501 --mount type=bind,source=/data/models/text_models/,target=/models/text_models -t tensorflow/serving --model_config_file=/models/text_models/model.config

同时加载model_config_list中的4个模型。

这里简单说一下docker这里的指令:

-p 8501:8501 端口映射 宿主机端口:docker容器端口 即访问docker只需要访问宿主机即可
--mount type=bind,source=xxxx,target=xxxx 将宿主机的source文件夹挂在到docker容器中target文件夹 
mount和-v的区别是 -v宿主机不存在目录时会自动创建 mount会直接报错
--model_config_file 也就是tf serving docker运行时需要加载的文件 注意!!!此处的路径就需要用target挂载之后的路径了 因为这个docker内部参数!!
5.预测

对于bert finetune的模型预测需要注意入参。因为在训练时候bert的输入参数由多个输入组成。

首先使用curl直观看一下:

curl -d '{"signature_name": "serving_default", "instances": [{"label_ids": 0,"input_ids": [101, 1920, 1004, 6873, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "input_mask": [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "segment_ids": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}]}' -X POST  "http://localhost:8501/v1/models/bert:predict"

可以看到输入是由label_ids input_ids input_mask segment_ids四个部分组成。这里其实从训练源代码也可以看出来。

那么咱们根据源码来构造请求。这里使用flask演示(因为tfserving提供了http接口 所以普通post请求就可以)详细见web.py文件https://github.com/hiroLinGoing/bert/tree/master

#这个方法就是google-bert run_classify.py中创建输入的方法
def convert_single_example(max_seq_length,
                           tokenizer, text_a, text_b=None):
    tokens_a = tokenizer.tokenize(text_a)
    tokens_b = None
    if text_b:
    print(......)
def detect_bert(content):
    input_ids, input_mask, segment_ids = convert_single_example(128, token, content)
    features = {}
    #构造具体请求参数
    features["input_ids"] = input_ids
    features["input_mask"] = input_mask
    features["segment_ids"] = segment_ids
    features["label_ids"] = 0
    data_list = []
    data_list.append(features)
    data = json.dumps({"signature_name": "serving_default", "instances": data_list})
    headers = {"content-type": "application/json"}
    json_response = requests.post('http://localhost:8501/v1/models/bert:predict', data=data, headers=headers)
    print(data, json_response.text)
    return jsonify(json.loads(json_response.text))

@app.route('/api', methods=['GET'])
def detect():
    if 'method' not in request.args.keys():
        raise Exception('method is empty.....')
    method = request.args.get("method")
    content = request.args.get("sen")
    if method == 'bert':
        return detect_bert(content)

输出结果:

此处时一个三分类的问题。

{
    "predictions": [[0.0116323624, 0.00175498601, 0.986612678]
    ]
}

至此一个bert分类问题如何生成pb文件,tfserving加载,预测就结束咯…

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章