使用 requests 調用 TensorFlow Serving 中部署的 BERT 模型

具體代碼:
https://github.com/danan0755/Bert_Classifier/blob/master/Bert_Predict.py

def request_from_raw_text(vocab_file, label2id_file, query, model_key,
                          max_seq_length=128,
                          server_url='http://localhost:8501',
                          timeout=None):
    """Classify a single query with a BERT model hosted by TensorFlow Serving.

    The query is tokenized with the BERT vocabulary, converted into model
    features via ``convert_single_example``, and POSTed as JSON to the TF
    Serving REST ``:predict`` endpoint.

    :param vocab_file: path to the BERT vocabulary file used by the tokenizer.
    :param label2id_file: path to a pickled dict mapping label -> id.
    :param query: raw text to classify.
    :param model_key: model name used in the TF Serving URL.
    :param max_seq_length: sequence length passed to ``convert_single_example``
        (presumably must match the length the model was exported with — confirm).
    :param server_url: base URL of the TF Serving REST API.
    :param timeout: optional ``requests`` timeout in seconds; ``None`` waits
        indefinitely (the original behavior).
    :return: ``(pred_score, label)`` — the largest value of the first
        prediction row and the label mapped from its argmax index (``None``
        if that index is not in the mapping).
    :raises FileNotFoundError: if ``label2id_file`` does not exist.
    :raises requests.HTTPError: if TF Serving responds with an error status.
    :raises ValueError: if the response contains no predictions.
    """
    if not os.path.exists(label2id_file):
        # The id -> label mapping is required to interpret the prediction;
        # fail fast instead of after the HTTP round-trip.
        raise FileNotFoundError('label2id file not found: {}'.format(label2id_file))
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(label2id_file, 'rb') as rf:
        label2id = pickle.load(rf)
    id2label = {value: key for key, value in label2id.items()}
    label_list = list(label2id)

    tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=True)

    data_list = []
    predict_examples = _create_examples([query], label2id_file)
    for ex_index, example in enumerate(predict_examples):
        feature = convert_single_example(ex_index, example, label_list,
                                         max_seq_length, tokenizer)
        data_list.append({
            "input_ids": feature.input_ids,
            "input_mask": feature.input_mask,
            "segment_ids": feature.segment_ids,
            "label_ids": feature.label_id,
        })

    data = json.dumps({"signature_name": "serving_default", "instances": data_list})
    headers = {"content-type": "application/json"}
    url = '{}/v1/models/{}:predict'.format(server_url.rstrip('/'), model_key)
    json_response = requests.post(url, data=data, headers=headers, timeout=timeout)
    # Surface server-side failures instead of indexing into an error payload.
    json_response.raise_for_status()
    predictions = json_response.json().get('predictions')
    if not predictions:
        raise ValueError('TF Serving returned no predictions: {}'.format(json_response.text))

    # Only one query is sent, so only the first prediction row is used.
    p_list = predictions[0]
    # np.argmax returns a NumPy integer; convert to a plain int for the lookup.
    label_index = int(np.argmax(p_list))
    label = id2label.get(label_index)
    pred_score = max(p_list)
    return pred_score, label
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章