小白Bert系列-生成pb模型,tfserving加載,flask進行預測

bert分類模型使用tfserving部署。

bert模型服務化現在已經有對應開源庫部署。

例如:1.https://github.com/macanv/BERT-BiLSTM-CRF-NER 該項目支持三種不同的任務

2.使用已有的包
pip install bert-serving-server # 服務端
pip install bert-serving-client # 客戶端,與服務端互相獨立

本文主要記錄通過bert預訓練的模型如何使用tfserving進行加載,這樣在已有的tfserving model_config服務下增加一個config就可以了。

1.bert 分類模型見這篇文章。

模型實踐(二)bert

2.模型輸出文件

訓練過程種需要注意在該文件run_classifier.py額外需要添加

def convert_single_example(ex_index, example, label_list, max_seq_length,
                           tokenizer):
  """Converts a single `InputExample` into a single `InputFeatures`."""

  if isinstance(example, PaddingInputExample):
    return InputFeatures(
        input_ids=[0] * max_seq_length,
        input_mask=[0] * max_seq_length,
        segment_ids=[0] * max_seq_length,
        label_id=0,
        is_real_example=False)

  label_map = {}
  for (i, label) in enumerate(label_list):
    label_map[label] = i

  #!!!!額外添加 保存標籤信息和枚舉的映射關係 並寫入label2id.pkl 該文件會一起輸出到output文件夾下
  output_label2id_file = os.path.join(FLAGS.output_dir, "label2id.pkl")
  if not os.path.exists(output_label2id_file):
    with open(output_label2id_file,'wb') as w:
      pickle.dump(label_map,w)

label2id.pkl其實就是一個映射關係(自定義label和模型使用label),也可以自己序列化寫入:

# pickle.load(open("label2id.pkl", 'rb'))
{'1': 0, '2': 1, '3': 2}

實際訓練是使用這樣的命令:

export DATA_DIR=/data/bert/data
export BERT_BASE_DIR=/data/bert/model/chinese_L-12_H-768_A-12
export OUTPUT_DIR=/data/bert/

python run_classifier.py --task_name=news --do_train=true --do_eval=true --data_dir=$DATA_DIR/ --vocab_file=$BERT_BASE_DIR/vocab.txt --bert_config_file=$BERT_BASE_DIR/bert_config.json --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt --max_seq_length=128 --train_batch_size=32 --learning_rate=2e-5 --num_train_epochs=3.0 --output_dir=$OUTPUT_DIR/output

訓練完成後output文件夾大致如下:

drwxr-xr-x 3 root root       4096 May 30 15:58 ./
drwxr-xr-x 7 root root       4096 May 30 16:08 ../
-rw-r--r-- 1 root root        222 May 26 21:58 checkpoint
drwxr-xr-x 2 root root       4096 May 26 21:59 eval/
-rw-r--r-- 1 root root         86 May 26 21:59 eval_results.txt
-rw-r--r-- 1 root root    1687178 May 26 21:58 eval.tf_record
-rw-r--r-- 1 root root   14356236 May 26 21:58 events.out.tfevents.1590500393.instance-py09166k
-rw-r--r-- 1 root root    9528376 May 26 21:40 graph.pbtxt
-rw-r--r-- 1 root root         38 May 26 21:39 label2id.pkl
-rw-r--r-- 1 root root 1227239468 May 26 21:40 model.ckpt-0.data-00000-of-00001
-rw-r--r-- 1 root root      22717 May 26 21:40 model.ckpt-0.index
-rw-r--r-- 1 root root    4046471 May 26 21:40 model.ckpt-0.meta
-rw-r--r-- 1 root root 1227239468 May 26 21:48 model.ckpt-1000.data-00000-of-00001
-rw-r--r-- 1 root root      22717 May 26 21:48 model.ckpt-1000.index
-rw-r--r-- 1 root root    4046471 May 26 21:48 model.ckpt-1000.meta
-rw-r--r-- 1 root root 1227239468 May 26 21:56 model.ckpt-2000.data-00000-of-00001
-rw-r--r-- 1 root root      22717 May 26 21:56 model.ckpt-2000.index
-rw-r--r-- 1 root root    4046471 May 26 21:56 model.ckpt-2000.meta
-rw-r--r-- 1 root root 1227239468 May 26 21:58 model.ckpt-2250.data-00000-of-00001
-rw-r--r-- 1 root root      22717 May 26 21:58 model.ckpt-2250.index
-rw-r--r-- 1 root root    4046471 May 26 21:58 model.ckpt-2250.meta
-rw-r--r-- 1 root root    1615474 May 26 22:28 predict.tf_record
-rw-r--r-- 1 root root     105910 May 26 22:29 test_results.tsv
-rw-r--r-- 1 root root   13543303 May 26 21:39 train.tf_record

來大致看一下幾個文件作用:

checkpoint 記錄可用的模型信息
eval_results.txt 驗證集的結果信息
eval.tf_record 記錄驗證集的二進制信息
events.out.tfevents.1590500393.instance-py09166k 用於tensorboard查看詳細信息
graph.pbtxt 記錄tensorflow的結構信息
label2id.pkl 標籤信息 (額外加的)
model.ckpt-0* 這裏是記錄最近的三個文件
model.ckpt-2250.data 所有變量的值
model.ckpt-2250.index 可能是用於映射圖和權重關係,0.11版本後引入
model.ckpt-2250.meta 記錄完整的計算圖結構
predict.tf_record 預測的二進制文件
test_results.tsv 使用預測後生成的預測結果
3.生成pb模型

生成pb模型藉助一箇中間的文件,創建 freeze_graph.py 放置於bert文件夾下。完整版本見github:https://github.com/hiroLinGoing/bert

運行:

export DATA_DIR=/data/bert/output
export BERT_BASE_DIR=/data/bert/model/chinese_L-12_H-768_A-12

#bertPB 表示模型名。
#pb_version 模型版本號

python freeze_graph.py -bert_model_dir=$BERT_BASE_DIR -model_dir=$DATA_DIR -max_seq_len=128 -num_labels=3 -pb_file_name bertPB -pb_version 100001

運行後會在output文件夾下生成模型文件,如下:

在這裏插入圖片描述

直接將bertPB文件夾拷貝到tfserving所配置得文件夾下。

4.運行服務

此處包括了多個模型,model_config:

model_config_list:{
    config:{
      name:"cnn",
      base_path:"/models/text_models/cnn",
      model_platform:"tensorflow"
    },
    config:{
      name:"rnn",
      base_path:"/models/text_models/rnn",
      model_platform:"tensorflow"
    },
    config:{
      name:"lstm",
      base_path:"/models/text_models/lstm",
      model_platform:"tensorflow"
    },
    config:{
      name:"bert",
      base_path:"/models/text_models/bertPB",
      model_platform:"tensorflow"
    }
}

運行tfserving docker

docker run -p 8501:8501 --mount type=bind,source=/data/models/text_models/,target=/models/text_models -t tensorflow/serving --model_config_file=/models/text_models/model.config

同時加載model_config_list中的4個模型。

這裏簡單說一下docker這裏的指令:

-p 8501:8501 端口映射 宿主機端口:docker容器端口 即訪問docker只需要訪問宿主機即可
--mount type=bind,source=xxxx,target=xxxx 將宿主機的source文件夾掛在到docker容器中target文件夾 
mount和-v的區別是 -v宿主機不存在目錄時會自動創建 mount會直接報錯
--model_config_file 也就是tf serving docker運行時需要加載的文件 注意!!!此處的路徑就需要用target掛載之後的路徑了 因爲這個docker內部參數!!
5.預測

對於bert finetune的模型預測需要注意入參。因爲在訓練時候bert的輸入參數由多個輸入組成。

首先使用curl直觀看一下:

curl -d '{"signature_name": "serving_default", "instances": [{"label_ids": 0,"input_ids": [101, 1920, 1004, 6873, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "input_mask": [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "segment_ids": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}]}' -X POST  "http://localhost:8501/v1/models/bert:predict"

可以看到輸入是由label_ids input_ids input_mask segment_ids四個部分組成。這裏其實從訓練源代碼也可以看出來。

那麼咱們根據源碼來構造請求。這裏使用flask演示(因爲tfserving提供了http接口 所以普通post請求就可以)詳細見web.py文件https://github.com/hiroLinGoing/bert/tree/master

#這個方法就是google-bert run_classify.py中創建輸入的方法
def convert_single_example(max_seq_length,
                           tokenizer, text_a, text_b=None):
    tokens_a = tokenizer.tokenize(text_a)
    tokens_b = None
    if text_b:
    print(......)
def detect_bert(content):
    input_ids, input_mask, segment_ids = convert_single_example(128, token, content)
    features = {}
    #構造具體請求參數
    features["input_ids"] = input_ids
    features["input_mask"] = input_mask
    features["segment_ids"] = segment_ids
    features["label_ids"] = 0
    data_list = []
    data_list.append(features)
    data = json.dumps({"signature_name": "serving_default", "instances": data_list})
    headers = {"content-type": "application/json"}
    json_response = requests.post('http://localhost:8501/v1/models/bert:predict', data=data, headers=headers)
    print(data, json_response.text)
    return jsonify(json.loads(json_response.text))

@app.route('/api', methods=['GET'])
def detect():
    if 'method' not in request.args.keys():
        raise Exception('method is empty.....')
    method = request.args.get("method")
    content = request.args.get("sen")
    if method == 'bert':
        return detect_bert(content)

輸出結果:

此處時一個三分類的問題。

{
    "predictions": [[0.0116323624, 0.00175498601, 0.986612678]
    ]
}

至此一個bert分類問題如何生成pb文件,tfserving加載,預測就結束咯…

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章