1、任務及數據集描述
實現利用bert預訓練模型進行中文新聞分類,共10類,使用的數據集情況:
其中,train.txt, dev.txt, test.txt內容格式爲每一行爲“內容 Tab 標籤”:
class.txt內容爲10類的新聞標籤,如上面的0就代表finance這一類。
2、bert模型準備
(1)下載bert中文預訓練模型chinese_L-12_H-768_A-12,解壓后里麪包含5個文件:模型、配置文件與詞典。
(2)去github上下載bert源碼:https://github.com/google-research/bert.git
3、修改源碼實現文本分類
我們只需要將我們的數據輸入處理成標準的結構輸入就可以了,在run_classifier.py文件中,有一個DataProcessor基類:
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
raise NotImplementedError()
def get_dev_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the dev set."""
raise NotImplementedError()
def get_test_examples(self, data_dir):
"""Gets a collection of `InputExample`s for prediction."""
raise NotImplementedError()
def get_labels(self):
"""Gets the list of labels for this data set."""
raise NotImplementedError()
@classmethod
def _read_tsv(cls, input_file, quotechar=None):
"""Reads a tab separated value file."""
with tf.gfile.Open(input_file, "r") as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
lines = []
for line in reader:
lines.append(line)
return lines
在這個基類中定義了一個讀取文件的靜態方法_read_tsv,四個分別獲取train set,dev set,test set和lable的方法。下面我們要按照這種形式定義自己的數據處理類,類命名爲MyTaskProcessor。
(1)編寫MyTaskProcessor類,讓MyTaskProcessor繼承DataProcessor,用於定義自己的任務:
class MyTaskProcessor(DataProcessor):
'''用於自己的任務——news classification'''
def __init__(self, data_dir):
# self.labels = ['財經', '房產', '股票', '教育', '科技', '社會', '時政', '體育', '遊戲', '娛樂']
self.labels = [c.strip() for c in open(os.path.join(data_dir, "class.txt")).readlines()]
def get_train_examples(self, data_dir):
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.txt")), "train")
def get_dev_examples(self, data_dir):
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.txt")), "val")
def get_test_examples(self, data_dir):
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.txt")), "test")
def get_labels(self):
return self.labels
def _create_examples(self, lines, set_type):
'''create examples for the training and val sets'''
examples = []
for (i, line) in enumerate(lines):
guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[0])
label = tokenization.convert_to_unicode(self.labels[int(line[1])])
examples.append(InputExample(guid=guid, text_a=text_a, label=label))
return examples
這裏的_read_tsv()方式是讀取的是每行以Tab分割的數據,如果自己的數據是以其他形式分割的,需要自己重寫一個方法,並在_create_examples()中更改。
這樣就完成了ber輸入的數據形式轉換,接下來就可以進行模型訓練了。
(2)修改main()函數,進行訓練
...
processor = processors[task_name](FLAGS.data_dir)
...
最後,給定代碼最開始的輸入參數,包括:“data_dir”, “bert_config_file”, “task_name”, “vocab_file”, “output_dir”, “init_checkpoint”, “do_train”, “do_eval”, “do_predict” 以及一些超參設置。
其中data_dir是你的要訓練的文本的數據所在的文件夾,bert_config_file, init_checkpoint是你的bert預訓練模型存放的地址,task_name要與上面main中的紅框中的名字一致。下面的幾個參數,do_train代表是否進行fine tune,do_eval代表是否進行evaluation,一般設置爲True, 還有未出現的參數do_predict代表是否進行預測。如果不需要進行fine tune,或者顯卡配置太低的話,可以將do_trian去掉。max_seq_length代表了句子的最長長度,當顯存不足時,可以適當降低max_seq_length。
其他:
(1)在訓練時輸出loss
bert源碼中,在run_classifier.py文件中,訓練模型和驗證模型用的都是tensorflow中的estimator接口,因此無法實現在訓練迭代100步就用驗證集驗證一次,在run_classifier.py文件中提供的方法是先運行完所有的epochs之後,再加載模型進行驗證,這種無法在訓練時輸出驗證集上的結果,不能很直觀的看到損失函數的變化,所以也就不能確定模型是否收斂。
要實現在訓練過程中輸出loss日誌,我們可以使用hooks參數實現:
# 訓練
train_input_fn = file_based_input_fn_builder(
input_file=train_file,
seq_length=FLAGS.max_seq_length,
is_training=True,
drop_remainder=True)
# 使用hooks參數在訓練過程中輸出loss日誌
tensors_to_log = {"train loss": "loss/Mean:0"}
logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=100)
estimator.train(input_fn=train_input_fn, max_steps=num_train_steps, hooks=[logging_hook])
(2)增加驗證集輸出的指標值,loss, accuracy,auc,recall,precision。
def metric_fn(per_example_loss, label_ids, logits, is_real_example):
'''驗證,輸出指標:loss, accuracy, auc,recall,precision'''
predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
accuracy = tf.metrics.accuracy(
labels=label_ids, predictions=predictions, weights=is_real_example)
loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example)
auc = tf.metrics.auc(labels=label_ids, predictions=predictions, weights=is_real_example)
precision = tf.metrics.precision(labels=label_ids, predictions=predictions, weights=is_real_example)
recall = tf.metrics.recall(labels=label_ids, predictions=predictions, weights=is_real_example)
return {
"eval_accuracy": accuracy,
"eval_loss": loss,
"eval_auc": auc,
"eval_precision": precision,
"eval_recall": recall
}
以上就是利用google開源的bert模型實現文本分類任務,如有不對之處請指正。