Multi-class text classification with BERT

The code is on GitHub: https://github.com/danan0755/Bert_Classifier

Data source: the cnews data set, which can be downloaded from Baidu Netdisk.

Link: https://pan.baidu.com/s/1LzTidW_LrdYMokN---Nyag
Extraction code: zejw
 

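Data format: each line of cnews.train.txt, cnews.val.txt and cnews.test.txt is one news article. It is written as the category label, a tab character, and the article text.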
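The sketch below illustrates this line layout. The hypothetical helper `parse_cnews_line` splits one line into its label and text at the first tab, which is the same split `_create_examples` performs later. The sample line is made up for illustration and is not taken from the data set.

```python
def parse_cnews_line(line):
    """Split one cnews line ('label<TAB>text') into (label, text).

    Hypothetical helper for illustration: it mirrors the tab split done in
    _create_examples, but splits only on the first tab.
    """
    parts = line.strip().split("\t", 1)
    if len(parts) != 2:
        raise ValueError("expected 'label<TAB>text', got: %r" % line[:50])
    label, text = parts
    return label, text


# Made-up sample line, not from the real data set
label, text = parse_cnews_line("体育\t这是一段示例新闻正文。\n")
print(label)  # 体育
```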

 

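Download link for the Chinese BERT pre-trained model. The commands below expect it to be unpacked to pretrained_model/chinese_L-12_H-768_A-12.

Link: https://pan.baidu.com/s/14JcQXIBSaWyY7bRWdJW7yg
Extraction code: mvtl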
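Before training, you can check that the files are in place. The sketch below assumes the paths used by the commands in this post. The three data file names come from `MyProcessor`, and the hypothetical function `missing_files` lists any expected file that is absent. `--init_checkpoint=.../bert_model.ckpt` is a checkpoint prefix, not a single file. For that reason the check looks for the `.index` file that TensorFlow writes next to the checkpoint data.

```python
import os

# Paths taken from the commands in this post; adjust if you unpack elsewhere.
DATA_DIR = "cnews"
MODEL_DIR = os.path.join("pretrained_model", "chinese_L-12_H-768_A-12")

EXPECTED_FILES = [
    os.path.join(DATA_DIR, "cnews.train.txt"),
    os.path.join(DATA_DIR, "cnews.val.txt"),
    os.path.join(DATA_DIR, "cnews.test.txt"),
    os.path.join(MODEL_DIR, "vocab.txt"),
    os.path.join(MODEL_DIR, "bert_config.json"),
    # bert_model.ckpt is a prefix; the .index file is one of its parts
    os.path.join(MODEL_DIR, "bert_model.ckpt.index"),
]


def missing_files(root="."):
    """Return the expected files (relative paths) that are not present under root."""
    return [p for p in EXPECTED_FILES if not os.path.isfile(os.path.join(root, p))]


if __name__ == "__main__":
    missing = missing_files()
    if missing:
        print("Missing:", *missing, sep="\n  ")
    else:
        print("All data and model files found.")
```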

 

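Copy run_classifier.py from Google's BERT repository (google-research/bert) and name the copy run_cnews_classifier.py. Then add a custom Processor for the cnews data: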

import random  # add at the top of run_cnews_classifier.py; run_classifier.py does not import it


class MyProcessor(DataProcessor):
    """Processor for the cnews data set (one 'label<TAB>text' example per line)."""

    def read_txt(self, file_path, flag):
        with open(file_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        # Fixed seed, so every run draws the same subset
        random.seed(0)
        random.shuffle(lines)
        # Use only a small subset of the data for training
        if flag == "train":
            lines = lines[0:5000]
        elif flag == "dev":
            lines = lines[0:500]
        elif flag == "test":
            lines = lines[0:100]
        return lines

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self.read_txt(os.path.join(data_dir, "cnews.train.txt"), "train"), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self.read_txt(os.path.join(data_dir, "cnews.val.txt"), "dev"), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self.read_txt(os.path.join(data_dir, "cnews.test.txt"), "test"), "test")

    def get_labels(self):
        """See base class."""
        # Must match the label strings in the data files exactly (simplified Chinese)
        return ["体育", "娱乐", "家居", "房产", "教育", "时尚", "时政", "游戏", "科技", "财经"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""
        examples = []
        for (i, line) in enumerate(lines):
            # Inherited from the MRPC processor, whose files start with a header row.
            # The cnews files have no header, so this only drops one shuffled example.
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, i)
            split_line = line.strip().split("\t")
            text_a = tokenization.convert_to_unicode(split_line[1])
            text_b = None
            if set_type == "test":
                # Dummy label: prediction ignores it, but it must appear in get_labels()
                label = "体育"
            else:
                label = tokenization.convert_to_unicode(split_line[0])
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples
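The processor depends on two ideas: reproducible subsampling and mapping each label to an integer id. Both can be sketched in plain Python. `take_subset`, `build_label_map` and `to_label_id` are hypothetical names chosen for illustration. The dict lookup reflects what `convert_single_example` in run_classifier.py does with the list returned by `get_labels()`. Because the lookup is exact, a label spelled differently from the data (for example, traditional instead of simplified characters) fails with `KeyError`.

```python
import random

# Label order as in get_labels(); cnews files spell labels in simplified Chinese.
LABELS = ["体育", "娱乐", "家居", "房产", "教育", "时尚", "时政", "游戏", "科技", "财经"]
SUBSET_SIZES = {"train": 5000, "dev": 500, "test": 100}


def take_subset(lines, flag, seed=0):
    """Shuffle with a fixed seed, then keep the first N lines (N depends on flag)."""
    lines = list(lines)                 # copy so the caller's list is not reordered
    random.Random(seed).shuffle(lines)  # local generator: reproducible, no global state
    return lines[:SUBSET_SIZES[flag]]


def build_label_map(labels):
    """Map each label string to an integer id, following the order of labels."""
    return {label: i for i, label in enumerate(labels)}


def to_label_id(label_map, label):
    # A plain dict lookup: a label spelled differently from get_labels()
    # (e.g. traditional instead of simplified characters) raises KeyError.
    return label_map[label]


label_map = build_label_map(LABELS)
print(to_label_id(label_map, "财经"))  # 9
```

The sketch uses a local `random.Random(seed)` rather than reseeding the global generator, which keeps other code that uses `random` unaffected.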

 

Then register the custom Processor in main():

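def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    # main() looks up FLAGS.task_name.lower() in this dict,
    # so "cnews" is the value to pass as --task_name
    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
        "xnli": XnliProcessor,
        "cnews": MyProcessor
    }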
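In run_classifier.py, main() does three things with `--task_name`:

1. It lower-cases the value.
2. It looks the result up in `processors`.
3. It instantiates the class it finds.

The length of `get_labels()` then becomes `num_labels`, which sets the number of output classes of the classifier. Below is a minimal stand-alone sketch of that lookup. A stub class stands in for the real `MyProcessor`, and `select_processor` is a hypothetical name.

```python
class MyProcessor:
    """Stub standing in for the real MyProcessor; only get_labels() is sketched."""

    def get_labels(self):
        return ["体育", "娱乐", "家居", "房产", "教育", "时尚", "时政", "游戏", "科技", "财经"]


PROCESSORS = {"cnews": MyProcessor}


def select_processor(task_name):
    """Lower-case the flag value, look it up, and instantiate the class."""
    key = task_name.lower()
    if key not in PROCESSORS:
        raise ValueError("Task not found: %s" % key)
    return PROCESSORS[key]()


processor = select_processor("CNEWS")
print(type(processor).__name__)     # MyProcessor
print(len(processor.get_labels()))  # 10 -> number of output classes
```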

 

Training command:

python run_cnews_classifier.py --task_name=cnews --do_train=true --do_eval=true --do_predict=false --data_dir=cnews --vocab_file=pretrained_model/chinese_L-12_H-768_A-12/vocab.txt --bert_config_file=pretrained_model/chinese_L-12_H-768_A-12/bert_config.json --init_checkpoint=pretrained_model/chinese_L-12_H-768_A-12/bert_model.ckpt --train_batch_size=32 --max_seq_length=128 --output_dir=model
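Two flags in this command shape the run.

- **Number of training steps.** run_classifier.py computes it as `int(num_examples / train_batch_size * num_train_epochs)`, and the warm-up steps as `int(num_train_steps * warmup_proportion)`. The defaults are 3.0 epochs and 0.1 warm-up.
- **Sequence length.** For single-sentence input, it keeps at most `max_seq_length - 2` tokens to leave room for `[CLS]` and `[SEP]`.

The sketch below reproduces both calculations and leaves out padding and the input mask. `train_steps` and `truncate_single` are hypothetical names. The Chinese BERT tokenizer splits CJK characters into separate tokens, so at `--max_seq_length=128` roughly the first 126 characters of each article are used.

```python
DEFAULT_EPOCHS = 3.0   # run_classifier.py default for --num_train_epochs
DEFAULT_WARMUP = 0.1   # run_classifier.py default for --warmup_proportion


def train_steps(num_examples, batch_size, epochs=DEFAULT_EPOCHS, warmup=DEFAULT_WARMUP):
    """Return (num_train_steps, num_warmup_steps) using run_classifier.py's formula."""
    steps = int(num_examples / batch_size * epochs)
    return steps, int(steps * warmup)


def truncate_single(tokens, max_seq_length):
    """Keep at most max_seq_length - 2 tokens, then add [CLS] and [SEP]."""
    body = tokens[:max_seq_length - 2]
    return ["[CLS]"] + body + ["[SEP]"]


# 5000 training lines are read; _create_examples skips the one at i == 0
print(train_steps(4999, 32))               # (468, 46)
print(len(truncate_single(list("字" * 1000), 128)))  # 128
```

The 468 steps match the checkpoint name model/model.ckpt-468 that the prediction command loads.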

 

Prediction command (loads the fine-tuned checkpoint that training wrote to model/):

python run_cnews_classifier.py --task_name=cnews --do_train=false --do_eval=false --do_predict=true --data_dir=cnews --vocab_file=pretrained_model/chinese_L-12_H-768_A-12/vocab.txt --bert_config_file=pretrained_model/chinese_L-12_H-768_A-12/bert_config.json --init_checkpoint=model/model.ckpt-468 --max_seq_length=128 --output_dir=result
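With `--do_predict=true`, run_classifier.py writes `test_results.tsv` to `--output_dir` (here `result/`). The file has one line per test example, and each line holds the tab-separated class probabilities in `get_labels()` order. The rows follow the order of the examples `MyProcessor` builds, that is, the shuffled 100-line test subset minus the skipped first line. The sketch below turns each row into the most likely label. `decode_predictions` and `decode_file` are hypothetical helpers.

```python
LABELS = ["体育", "娱乐", "家居", "房产", "教育", "时尚", "时政", "游戏", "科技", "财经"]


def decode_predictions(tsv_lines, labels=LABELS):
    """Turn rows of tab-separated class probabilities into (label, probability)."""
    results = []
    for line in tsv_lines:
        probs = [float(x) for x in line.strip().split("\t")]
        best = max(range(len(probs)), key=probs.__getitem__)  # argmax
        results.append((labels[best], probs[best]))
    return results


def decode_file(path="result/test_results.tsv"):
    """Decode the prediction file written by the predict step."""
    with open(path, encoding="utf-8") as f:
        return decode_predictions(f)
```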

 

Results (evaluation on the dev set after training):
INFO:tensorflow:  eval_accuracy = 0.93386775
INFO:tensorflow:  eval_loss = 0.33081177
INFO:tensorflow:  global_step = 468
INFO:tensorflow:  loss = 0.3427003
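`eval_accuracy` is the share of dev examples whose predicted label matches the true label. The dev subset has 500 lines, and `_create_examples` skips the line at i == 0, which leaves 499 examples. Under that assumption, the logged value (a float32) works back to 466 correct predictions. This count is inferred here and is not printed by the script.

```python
def accuracy(num_correct, num_total):
    """Fraction of examples whose predicted label equals the true label."""
    return num_correct / num_total


dev_examples = 500 - 1                 # 500 dev lines minus the one skipped at i == 0
logged_accuracy = 0.93386775           # eval_accuracy from the log above
implied_correct = round(logged_accuracy * dev_examples)
print(implied_correct)  # 466
```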

 
