具體代碼:
https://github.com/danan0755/Bert_Classifier/blob/master/Bert_Predict.py
def request_from_raw_text(vocab_file, label2id_file, query, model_key,
                          max_seq_length=128,
                          server_url='http://localhost:8501',
                          timeout=None):
    """Classify one raw text query through a REST ``:predict`` endpoint.

    The query is tokenized and converted to model features, posted as JSON
    to ``<server_url>/v1/models/<model_key>:predict`` (the route format used
    by the TensorFlow Serving REST API), and the highest-scoring class is
    mapped back to its label name.

    :param vocab_file: path to the vocabulary file given to
        ``tokenization.FullTokenizer``.
    :param label2id_file: path to a pickled ``label -> id`` mapping.
    :param query: the raw text to classify.
    :param model_key: model name inserted into the request URL.
    :param max_seq_length: sequence length passed to
        ``convert_single_example`` (defaults to the previous fixed 128).
    :param server_url: base URL of the model server (defaults to the
        previous fixed ``http://localhost:8501``).
    :param timeout: timeout forwarded to ``requests.post``; ``None`` keeps
        the previous behaviour of waiting indefinitely.
    :return: ``(pred_score, label)`` - the highest score in the first
        prediction row and the label mapped to its index (``None`` when the
        index has no entry in the mapping).
    :raises FileNotFoundError: if ``label2id_file`` does not exist.
    :raises RuntimeError: if the response carries no ``predictions``.
    """
    # Load the label mapping up front: without it the result cannot be
    # decoded, so fail before tokenizing or contacting the server instead of
    # crashing on an unbound ``id2label`` after the request was already made.
    # NOTE(security): pickle.load can execute arbitrary code contained in the
    # file - only point this at label files from a trusted source.
    with open(label2id_file, 'rb') as rf:
        label2id = pickle.load(rf)
    id2label = {value: key for key, value in label2id.items()}
    label_list = list(label2id)

    tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=True)
    predict_examples = _create_examples([query], label2id_file)

    # One feature dict per example; the keys are the input names sent to the
    # serving signature.
    data_list = []
    for ex_index, example in enumerate(predict_examples):
        feature = convert_single_example(ex_index, example, label_list,
                                         max_seq_length, tokenizer)
        data_list.append({
            "input_ids": feature.input_ids,
            "input_mask": feature.input_mask,
            "segment_ids": feature.segment_ids,
            "label_ids": feature.label_id,
        })

    data = json.dumps({"signature_name": "serving_default", "instances": data_list})
    headers = {"content-type": "application/json"}
    url = '{}/v1/models/{}:predict'.format(server_url.rstrip('/'), model_key)
    json_response = requests.post(url, data=data, headers=headers, timeout=timeout)

    predictions = json.loads(json_response.text)
    rows = predictions.get('predictions')
    if not rows:
        # Surface the server's own error text rather than failing later with
        # "'NoneType' object is not subscriptable".
        raise RuntimeError(
            'prediction request to {} failed (HTTP {}): {}'.format(
                url, json_response.status_code,
                predictions.get('error', json_response.text)))

    p_list = rows[0]
    # Plain int so the dictionary lookup does not depend on a NumPy scalar.
    label_index = int(np.argmax(p_list))
    label = id2label.get(label_index)
    pred_score = max(p_list)
    return pred_score, label