說明
在 AllenNLP 中使用自定義 predictor 時,默認的輸入格式是 JSON(每行一個 JSON 對象),我們可以將其修改爲以行爲單位的純文本格式;
另外,默認的輸出也是 JSON,我們同樣可以自定義輸出格式。特別要注意的是,`json.dumps` 默認 `ensure_ascii=True`,會把中文轉義成 `\uXXXX` 形式的 ASCII 字符;自定義時將 `ensure_ascii` 設置爲 `False`,即可直接輸出中文字符;
此外,默認的輸出中只有預測結果(如 label),沒有原始輸入文本 input_text 作爲參考;我們可以在 outputs 中新增 input_text 字段,方便對照查看預測結果:
from copy import deepcopy
from typing import List, Dict
from overrides import overrides
import numpy
import json
from allennlp.common.util import JsonDict, sanitize
from allennlp.common.util import JsonDict
from allennlp.data import Instance
from allennlp.predictors.predictor import Predictor
from allennlp.data.fields import LabelField
from allennlp.data.tokenizers import CharacterTokenizer
@Predictor.register("cnews_text_classifier")
class TextClassifierPredictor(Predictor):
    """
    Predictor for single-sentence text classification models (for example
    AllenNLP's `BasicClassifier`).

    Registered as a `Predictor` with name "cnews_text_classifier".

    Differences from the stock JSON-in / JSON-out behaviour:

    * `load_line` treats every input line as raw text instead of JSON.
    * `dump_line` serializes with `ensure_ascii=False`, so non-ASCII
      (e.g. Chinese) characters are written as-is instead of being escaped.
    * Every prediction carries an extra `"input_text"` key echoing the input,
      both for single and for batched prediction.

    Input dicts may use either the `"text"` key (produced by `load_line`) or
    the `"sentence"` key.
    """

    def __init__(self, model, dataset_reader):
        super().__init__(model, dataset_reader)
        # Raw text of the most recent input converted by `_json_to_instance`;
        # `predict_instance` echoes it back as "input_text".
        self.input_text = ""

    @staticmethod
    def _get_text(json_dict: JsonDict) -> str:
        """Return the input text, accepting either the "text" or "sentence" key."""
        for key in ("text", "sentence"):
            if key in json_dict:
                return json_dict[key]
        raise KeyError('Expected an input dict with a "text" or "sentence" key')

    def predict(self, sentence: str) -> JsonDict:
        """Classify a single piece of text and return the model outputs."""
        # Use the same key that `load_line` produces; the previous
        # {"sentence": ...} payload made `_json_to_instance` raise KeyError.
        return self.predict_json({"text": sentence})

    @overrides
    def load_line(self, line: str) -> JsonDict:
        """
        Parse one line of plain-text input (instead of the default JSON).

        The trailing line terminator is removed so it ends up neither in the
        model input nor in the echoed "input_text".
        """
        return {"text": line.rstrip("\r\n")}

    @overrides
    def dump_line(self, outputs: JsonDict) -> str:
        """
        Serialize one prediction as a single JSON line.

        `ensure_ascii=False` keeps non-ASCII characters (e.g. Chinese)
        human-readable instead of escaping them.
        """
        return json.dumps(outputs, ensure_ascii=False) + "\n"

    @overrides
    def predict_instance(self, instance: Instance) -> JsonDict:
        """
        Run the model on one instance and add "input_text" to its outputs.

        The echoed text is `self.input_text`, i.e. the text last converted by
        `_json_to_instance`. If this method is called directly with an
        instance built elsewhere, that value may be stale or empty.
        """
        outputs = self._model.forward_on_instance(instance)
        outputs["input_text"] = self.input_text
        return sanitize(outputs)

    @overrides
    def predict_batch_json(self, inputs: List[JsonDict]) -> List[JsonDict]:
        """
        Batched prediction that, like `predict_instance`, adds "input_text".

        The inherited implementation goes through `predict_batch_instance`,
        which never sees the raw text. Without this override, batched outputs
        would lack the "input_text" key that single predictions carry.
        """
        instances = self._batch_json_to_instances(inputs)
        outputs = self.predict_batch_instance(instances)
        for output, json_dict in zip(outputs, inputs):
            output["input_text"] = self._get_text(json_dict)
        return outputs

    @overrides
    def _json_to_instance(self, json_dict: JsonDict) -> Instance:
        """
        Convert `{"text": "..."}` (or `{"sentence": "..."}`) into an Instance.

        Also records the raw text in `self.input_text` so that
        `predict_instance` can echo it back.
        """
        sentence = self._get_text(json_dict)
        self.input_text = sentence
        return self._dataset_reader.text_to_instance(sentence)

    @overrides
    def predictions_to_labeled_instances(
        self, instance: Instance, outputs: Dict[str, numpy.ndarray]
    ) -> List[Instance]:
        """
        Return a copy of `instance` labeled with the model's predicted class.

        The label is the argmax over `outputs["probs"]`, stored as an integer
        index (`skip_indexing=True`) rather than a label string.
        """
        new_instance = deepcopy(instance)
        label = numpy.argmax(outputs["probs"])
        new_instance.add_field("label", LabelField(int(label), skip_indexing=True))
        return [new_instance]