[AllenNLP]: Custom predictor with plain-text input and readable Chinese output

Overview

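By default, a custom AllenNLP predictor expects each input line to be a JSON object. This can be changed so that the predictor reads plain text instead, one example per line.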

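The default output is also JSON, and it can likewise be customized, for example into plain text. In particular, json.dumps escapes Chinese characters as ASCII \uXXXX sequences by default; passing ensure_ascii=False in a custom dump_line writes the Chinese characters themselves.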
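As a quick illustration of the ensure_ascii behaviour described above, here is a minimal sketch that uses only the standard-library json module. The sample dictionary is made up for the example and is not real model output.

```python
import json

# Hypothetical prediction output, used only for illustration.
outputs = {"label": "體育", "input_text": "今天的比賽很精彩"}

# Default behaviour (ensure_ascii=True): non-ASCII characters are escaped.
escaped = json.dumps(outputs)

# ensure_ascii=False keeps the Chinese characters as they are.
readable = json.dumps(outputs, ensure_ascii=False)

print(escaped)
print(readable)
```

Both strings decode back to the same dictionary with json.loads; the only difference is how readable the serialized text is.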

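Finally, the default output contains the predicted label but not the input text, which makes predictions hard to inspect. Adding input_text to outputs makes it easy to check each prediction against its input: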

import json
from copy import deepcopy
from typing import Dict, List

import numpy
from overrides import overrides

from allennlp.common.util import JsonDict, sanitize
from allennlp.data import Instance
from allennlp.data.fields import LabelField
from allennlp.predictors.predictor import Predictor


@Predictor.register("cnews_text_classifier")
class TextClassifierPredictor(Predictor):
    """
    Predictor for any model that takes in a sentence and returns
    a single class for it.  In particular, it can be used with
    the [`BasicClassifier`](../models/basic_classifier.md) model.
    Registered as a `Predictor` with name "cnews_text_classifier".
    """
    def __init__(self, model, dataset_reader):
        super().__init__(model, dataset_reader)
        # Most recent raw input text; echoed back in predict_instance.
        self.input_text = ""

    def predict(self, sentence: str) -> JsonDict:
        # The key must match the one read in _json_to_instance ("text").
        return self.predict_json({"text": sentence})

    @overrides
    def load_line(self, line: str) -> JsonDict:
        """
        Override this method if the input is plain text (one example per line) rather than JSON.
        """
        return {"text": line}

    @overrides
    def dump_line(self, outputs: JsonDict) -> str:
        """
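        Override this method to customize the output format.
        ensure_ascii=False keeps Chinese characters readable instead of escaping them.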
        """
        return json.dumps(outputs, ensure_ascii=False) + "\n"
    
    @overrides
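    def predict_instance(self, instance: Instance) -> JsonDict:
        outputs = self._model.forward_on_instance(instance)
        # Echo the input text next to the prediction. This relies on the value
        # stored by _json_to_instance, so it is only reliable when instances
        # are predicted one at a time.
        outputs["input_text"] = self.input_text
        return sanitize(outputs)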

    @overrides
    def _json_to_instance(self, json_dict: JsonDict) -> Instance:
        """
        Expects JSON that looks like `{"text": "..."}`.
        Stores the raw text so that `predict_instance` can add it to the output.
        """
        sentence = json_dict["text"]
        self.input_text = json_dict["text"]
        return self._dataset_reader.text_to_instance(sentence)

    @overrides
    def predictions_to_labeled_instances(self, instance: Instance, outputs: Dict[str, numpy.ndarray]) -> List[Instance]:
        new_instance = deepcopy(instance)
        label = numpy.argmax(outputs["probs"])
        new_instance.add_field("label", LabelField(int(label), skip_indexing=True))
        return [new_instance]
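To see how the two overridden hooks fit together without installing AllenNLP, here is a minimal stand-alone sketch of the same flow: plain-text line in, one JSON line out. The load_line and dump_line functions loosely mirror the methods above, except that stripping the trailing newline is an addition made in this sketch. The fake_predict function and its length-based label are hypothetical stand-ins for the real model and are not part of AllenNLP.

```python
import json
from typing import Any, Dict, List

JsonDict = Dict[str, Any]


def load_line(line: str) -> JsonDict:
    """Treat each input line as raw text rather than a JSON object."""
    # Stripping the newline is a choice made in this sketch.
    return {"text": line.rstrip("\n")}


def fake_predict(inputs: JsonDict) -> JsonDict:
    """Hypothetical stand-in for the model: labels the text by its length."""
    text = inputs["text"]
    label = "long" if len(text) > 5 else "short"
    # Carry the input text along so each prediction can be checked by eye.
    return {"label": label, "input_text": text}


def dump_line(outputs: JsonDict) -> str:
    """Write one prediction per line, keeping non-ASCII characters readable."""
    return json.dumps(outputs, ensure_ascii=False) + "\n"


def run(lines: List[str]) -> List[str]:
    """Apply load_line -> fake_predict -> dump_line to every input line."""
    return [dump_line(fake_predict(load_line(line))) for line in lines]


results = run(["今天天氣很好\n", "你好\n"])
for row in results:
    print(row, end="")
```

Because every output line carries its own input_text, the output file can be read on its own without looking anything up in the input file.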