利用bert進行文本分類

1、任務及數據集描述
實現利用bert預訓練模型進行中文新聞分類,共10類,使用的數據集情況:
在這裏插入圖片描述
其中,train.txt, dev.txt, test.txt內容格式爲每一行爲“內容 Tab 標籤”:
在這裏插入圖片描述
class.txt內容爲10類的新聞標籤,如上面的0就代表finance這一類。
在這裏插入圖片描述
2、bert模型準備

(1)下載bert中文預訓練模型chinese_L-12_H-768_A-12,解壓后里麪包含5個文件:模型、配置文件與詞典。
(2)去github上下載bert源碼:https://github.com/google-research/bert.git

3、修改源碼實現文本分類
我們只需要將我們的數據輸入處理成標準的結構輸入就可以了,在run_classifier.py文件中,有一個DataProcessor基類:

class DataProcessor(object):
  """Base class for data converters for sequence classification data sets."""

  def get_train_examples(self, data_dir):
    """Gets a collection of `InputExample`s for the train set."""
    raise NotImplementedError()

  def get_dev_examples(self, data_dir):
    """Gets a collection of `InputExample`s for the dev set."""
    raise NotImplementedError()

  def get_test_examples(self, data_dir):
    """Gets a collection of `InputExample`s for prediction."""
    raise NotImplementedError()

  def get_labels(self):
    """Gets the list of labels for this data set."""
    raise NotImplementedError()

  @classmethod
  def _read_tsv(cls, input_file, quotechar=None):
    """Reads a tab separated value file."""
    with tf.gfile.Open(input_file, "r") as f:
      reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
      lines = []
      for line in reader:
        lines.append(line)
      return lines

在這個基類中定義了一個讀取文件的靜態方法_read_tsv,四個分別獲取train set,dev set,test set和lable的方法。下面我們要按照這種形式定義自己的數據處理類,類命名爲MyTaskProcessor。

(1)編寫MyTaskProcessor類,讓MyTaskProcessor繼承DataProcessor,用於定義自己的任務:

class MyTaskProcessor(DataProcessor):
    '''用於自己的任務——news classification'''
    def __init__(self, data_dir):
        # self.labels = ['財經', '房產', '股票', '教育', '科技', '社會', '時政', '體育', '遊戲', '娛樂']
        self.labels = [c.strip() for c in open(os.path.join(data_dir, "class.txt")).readlines()]

    def get_train_examples(self, data_dir):
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.txt")), "train")

    def get_dev_examples(self, data_dir):
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.txt")), "val")

    def get_test_examples(self, data_dir):
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "test.txt")), "test")

    def get_labels(self):
        return self.labels

    def _create_examples(self, lines, set_type):
        '''create examples for the training and val sets'''
        examples = []
        for (i, line) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            text_a = tokenization.convert_to_unicode(line[0])
            label = tokenization.convert_to_unicode(self.labels[int(line[1])])
            examples.append(InputExample(guid=guid, text_a=text_a, label=label))
        return examples

這裏的_read_tsv()方式是讀取的是每行以Tab分割的數據,如果自己的數據是以其他形式分割的,需要自己重寫一個方法,並在_create_examples()中更改。
這樣就完成了ber輸入的數據形式轉換,接下來就可以進行模型訓練了。

(2)修改main()函數,進行訓練
在這裏插入圖片描述

  ...
  processor = processors[task_name](FLAGS.data_dir)
  ...

最後,給定代碼最開始的輸入參數,包括:“data_dir”, “bert_config_file”, “task_name”, “vocab_file”, “output_dir”, “init_checkpoint”, “do_train”, “do_eval”, “do_predict” 以及一些超參設置。

其中data_dir是你的要訓練的文本的數據所在的文件夾,bert_config_file, init_checkpoint是你的bert預訓練模型存放的地址,task_name要與上面main中的紅框中的名字一致。下面的幾個參數,do_train代表是否進行fine tune,do_eval代表是否進行evaluation,一般設置爲True, 還有未出現的參數do_predict代表是否進行預測。如果不需要進行fine tune,或者顯卡配置太低的話,可以將do_trian去掉。max_seq_length代表了句子的最長長度,當顯存不足時,可以適當降低max_seq_length。

其他:
(1)在訓練時輸出loss
bert源碼中,在run_classifier.py文件中,訓練模型和驗證模型用的都是tensorflow中的estimator接口,因此無法實現在訓練迭代100步就用驗證集驗證一次,在run_classifier.py文件中提供的方法是先運行完所有的epochs之後,再加載模型進行驗證,這種無法在訓練時輸出驗證集上的結果,不能很直觀的看到損失函數的變化,所以也就不能確定模型是否收斂。
要實現在訓練過程中輸出loss日誌,我們可以使用hooks參數實現:

# 訓練
    train_input_fn = file_based_input_fn_builder(
        input_file=train_file,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True)
    # 使用hooks參數在訓練過程中輸出loss日誌
    tensors_to_log = {"train loss": "loss/Mean:0"}
    logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=100)
    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps, hooks=[logging_hook])

(2)增加驗證集輸出的指標值,loss, accuracy,auc,recall,precision。

      def metric_fn(per_example_loss, label_ids, logits, is_real_example):
        '''驗證,輸出指標:loss, accuracy, auc,recall,precision'''
        predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
        accuracy = tf.metrics.accuracy(
            labels=label_ids, predictions=predictions, weights=is_real_example)
        loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example)
        auc = tf.metrics.auc(labels=label_ids, predictions=predictions, weights=is_real_example)
        precision = tf.metrics.precision(labels=label_ids, predictions=predictions, weights=is_real_example)
        recall = tf.metrics.recall(labels=label_ids, predictions=predictions, weights=is_real_example)

        return {
            "eval_accuracy": accuracy,
            "eval_loss": loss,
            "eval_auc": auc,
            "eval_precision": precision,
            "eval_recall": recall
        }

以上就是利用google開源的bert模型實現文本分類任務,如有不對之處請指正。

發佈了101 篇原創文章 · 獲贊 57 · 訪問量 12萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章