xlnet中文文本分類任務

xlnet中文版本預訓練模型終於出來了,見地址 https://github.com/ymcui/Chinese-PreTrained-XLNet 。出來之後嘗試了下中文文本分類模型。xlnet模型相比bert有很多東西做了改變,模型層面的不多說,目前放出來的中文模型採用的是24層的網絡結構,層數是中文版bert(12層)的兩倍。之前論文出來時候有很多,主要是中文數據處理的問題:模型採用sentencepiece做分詞,pad方式是在序列前面補齊(pre-padding,即下面代碼中的 [0] * delta_len + input_ids),模型輸入是len*batch的形式,還有segment_ids和mask也和普通的模型並不一樣。下面直接看代碼吧。

 

數據轉化爲tfrecord:

import  tensorflow  as tf
import sys
import six
import unicodedata
import sentencepiece as spm
import collections
from textclass import   FLAGS


# Segment-id values written into the `segment_ids` feature; they are
# consecutive integers starting at 0.
SEG_ID_A, SEG_ID_B, SEG_ID_CLS, SEG_ID_SEP, SEG_ID_PAD = range(5)

# Special vocabulary symbols mapped to their ids; a symbol's id is its
# position in the tuple below.
special_symbols = {
    symbol: index
    for index, symbol in enumerate((
        "<unk>",
        "<s>",
        "</s>",
        "<cls>",
        "<sep>",
        "<pad>",
        "<mask>",
        "<eod>",
        "<eop>",
    ))
}

VOCAB_SIZE = 32000

# Convenience aliases for the ids looked up from `special_symbols`.
UNK_ID = special_symbols["<unk>"]
CLS_ID = special_symbols["<cls>"]
SEP_ID = special_symbols["<sep>"]
MASK_ID = special_symbols["<mask>"]
EOD_ID = special_symbols["<eod>"]


# Module-level SentencePiece processor. It is created and loaded at import
# time from the model file named by FLAGS.spiece_model_file, and is the
# processor that tokenize_fn() passes to encode_ids().
sp = spm.SentencePieceProcessor()
sp.Load(FLAGS.spiece_model_file)

def _truncate_seq_pair(tokens_a, tokens_b, max_length):
  while True:
    total_length = len(tokens_a) + len(tokens_b)
    if total_length <= max_length:
      break
    if len(tokens_a) > len(tokens_b):
      tokens_a.pop()
    else:
      tokens_b.pop()

def get_class_ids(text,max_seq_length,tokenize_fn):
  """Build fixed-length model inputs for a single text.

  The text is tokenized with ``tokenize_fn`` and cut to at most
  ``max_seq_length - 2`` ids, then ``SEP_ID`` and ``CLS_ID`` are appended.
  Shorter sequences are padded at the FRONT: id ``0`` in ``input_ids``,
  ``1`` in ``input_mask`` and ``SEG_ID_PAD`` in ``segment_ids``.

  Args:
    text: the text to encode.
    tokenize_fn: callable mapping ``text`` to a sequence of token ids.
    max_seq_length: exact length of every returned list.

  Returns:
    A tuple ``(input_ids, input_mask, segment_ids)`` of lists, each of
    length ``max_seq_length``. ``input_mask`` is 0 on real positions and 1
    on padding; ``segment_ids`` is ``SEG_ID_A`` for the text and its
    separator and ``SEG_ID_CLS`` for the final classification token.

  Raises:
    AssertionError: if any resulting list is not ``max_seq_length`` long.
  """
  budget = max_seq_length - 2  # room left after the SEP and CLS ids
  piece_ids = tokenize_fn(text)
  if len(piece_ids) > budget:
    piece_ids = piece_ids[:budget]

  input_ids = list(piece_ids) + [SEP_ID, CLS_ID]
  segment_ids = [SEG_ID_A] * (len(piece_ids) + 1) + [SEG_ID_CLS]
  input_mask = [0] * len(input_ids)

  pad_len = max_seq_length - len(input_ids)
  if pad_len > 0:
    # Padding is prepended, so the real tokens sit at the end.
    input_ids = [0] * pad_len + input_ids
    input_mask = [1] * pad_len + input_mask
    segment_ids = [SEG_ID_PAD] * pad_len + segment_ids

  for feature in (input_ids, input_mask, segment_ids):
    assert len(feature) == max_seq_length

  return input_ids, input_mask, segment_ids


def get_pair_ids(text_a,text_b,max_seq_length,tokenize_fn):
  """Build fixed-length model inputs for a pair of texts.

  Both texts are tokenized with ``tokenize_fn`` and jointly truncated by
  ``_truncate_seq_pair`` to ``max_seq_length - 3`` ids. The layout is
  ``A SEP B SEP CLS``. Shorter sequences are padded at the FRONT: id ``0``
  in ``input_ids``, ``1`` in ``input_mask`` and ``SEG_ID_PAD`` in
  ``segment_ids``.

  Args:
    text_a: first text.
    text_b: second text.
    max_seq_length: exact length of every returned list.
    tokenize_fn: callable mapping a text to a list of token ids (the
      lists are truncated in place).

  Returns:
    A tuple ``(input_ids, input_mask, segment_ids)`` of lists, each of
    length ``max_seq_length``. ``input_mask`` is 0 on real positions and 1
    on padding. Segment ids are ``SEG_ID_A`` / ``SEG_ID_B`` for each text
    and the separator following it, and ``SEG_ID_CLS`` for the final token.

  Raises:
    AssertionError: if any resulting list is not ``max_seq_length`` long.
  """
  tokens_a = tokenize_fn(text_a)
  tokens_b = tokenize_fn(text_b)
  # Three positions are reserved for the two SEP ids and the CLS id.
  _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)

  input_ids = list(tokens_a) + [SEP_ID] + list(tokens_b) + [SEP_ID, CLS_ID]
  segment_ids = (
      [SEG_ID_A] * (len(tokens_a) + 1)
      + [SEG_ID_B] * (len(tokens_b) + 1)
      + [SEG_ID_CLS]
  )
  input_mask = [0] * len(input_ids)

  pad_len = max_seq_length - len(input_ids)
  if pad_len > 0:
    # Padding is prepended, so the real tokens sit at the end.
    input_ids = [0] * pad_len + input_ids
    input_mask = [1] * pad_len + input_mask
    segment_ids = [SEG_ID_PAD] * pad_len + segment_ids

  for feature in (input_ids, input_mask, segment_ids):
    assert len(feature) == max_seq_length

  return input_ids, input_mask, segment_ids



SPIECE_UNDERLINE = '▁'
def encode_pieces(sp_model, text, return_unicode=True, sample=False):
  """Split ``text`` into SentencePiece pieces, isolating trailing commas.

  A piece that ends with a digit followed by ``,`` is re-encoded without
  the comma (and without any ``SPIECE_UNDERLINE`` characters), and the
  comma is appended as a piece of its own. If the original piece did not
  start with ``SPIECE_UNDERLINE`` but the re-encoded result does, that
  leading marker is removed again.

  Args:
    sp_model: object providing ``EncodeAsPieces`` and
      ``SampleEncodeAsPieces``.
    text: text to split. On Python 2 a ``unicode`` value is first encoded
      as UTF-8.
    return_unicode: on Python 2, decode ``str`` pieces back to unicode
      before returning. Not consulted on Python 3.
    sample: when true, use ``SampleEncodeAsPieces(text, 64, 0.1)`` instead
      of ``EncodeAsPieces(text)``.

  Returns:
    The list of pieces.
  """
  if six.PY2 and isinstance(text, unicode):
    text = text.encode('utf-8')

  if sample:
    raw_pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
  else:
    raw_pieces = sp_model.EncodeAsPieces(text)

  result = []
  for piece in raw_pieces:
    digit_then_comma = (
        len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit())
    if not digit_then_comma:
      result.append(piece)
      continue

    # Re-encode the part before the comma on its own.
    sub_pieces = sp_model.EncodeAsPieces(
        piece[:-1].replace(SPIECE_UNDERLINE, ''))
    if piece[0] != SPIECE_UNDERLINE and sub_pieces[0][0] == SPIECE_UNDERLINE:
      # The re-encoding introduced a leading marker the original piece did
      # not have; drop it (the whole piece if it is only the marker).
      if len(sub_pieces[0]) == 1:
        sub_pieces = sub_pieces[1:]
      else:
        sub_pieces[0] = sub_pieces[0][1:]
    sub_pieces.append(piece[-1])
    result.extend(sub_pieces)

  # Convert back to unicode for Python 2.
  if six.PY2 and return_unicode:
    result = [
        item.decode('utf-8') if isinstance(item, str) else item
        for item in result
    ]

  return result



def encode_ids(sp_model, text, sample=False):
  """Encode ``text`` into a list of SentencePiece ids.

  The text is split with ``encode_pieces`` (with ``return_unicode=False``)
  and every piece is converted with ``sp_model.PieceToId``.

  Args:
    sp_model: object providing the SentencePiece encode methods and
      ``PieceToId``.
    text: text to encode.
    sample: forwarded to ``encode_pieces``.

  Returns:
    A list with one id per piece.
  """
  return list(map(
      sp_model.PieceToId,
      encode_pieces(sp_model, text, return_unicode=False, sample=sample)))

def preprocess_text(inputs, lower=False, remove_space=True, keep_accents=False):
  """Normalize raw text before SentencePiece encoding.

  Args:
    inputs: the text to normalize.
    lower: lowercase the result when true.
    remove_space: when true, strip the text and collapse every run of
      whitespace into a single space.
    keep_accents: when false, apply NFKD normalization and drop all
      combining characters.

  Returns:
    The normalized text. Double backticks and double single quotes are
    always replaced by a double-quote character. On Python 2 a byte
    ``str`` is decoded as UTF-8.
  """
  text = ' '.join(inputs.strip().split()) if remove_space else inputs
  text = text.replace("``", '"').replace("''", '"')

  if six.PY2 and isinstance(text, str):
    text = text.decode('utf-8')

  if not keep_accents:
    decomposed = unicodedata.normalize('NFKD', text)
    text = ''.join(ch for ch in decomposed if not unicodedata.combining(ch))

  return text.lower() if lower else text


def tokenize_fn(text):
    """Lowercase/normalize ``text`` and encode it with the module-level ``sp``.

    Returns:
        The list of ids produced by ``encode_ids`` for the text returned by
        ``preprocess_text(text, lower=True)``.
    """
    normalized = preprocess_text(text, lower=True)
    ids = encode_ids(sp, normalized)
    return ids


def get_vocab(path):
    """Read a one-entry-per-line file into an ``entry -> line index`` mapping.

    Each line is stripped of surrounding whitespace and mapped to its
    zero-based line number. If an entry occurs more than once, the index
    of its last occurrence wins.

    Args:
        path: path of the file, opened through ``tf.gfile.GFile``.

    Returns:
        A dict mapping each stripped line to its zero-based line index.
        Missing keys raise ``KeyError`` on indexing and return ``None``
        from ``.get()``.
    """
    # A plain dict replaces the factory-less defaultdict, which never
    # supplied defaults anyway. The `with` block closes the file, so no
    # explicit close() is needed.
    entry_to_index = {}
    with tf.gfile.GFile(path, "r") as f:
        for index, line in enumerate(f):
            entry_to_index[line.strip()] = index
    return entry_to_index


def writedataclass(inputpath, vocab, outputpath,max_seq_length,tokenize_fn):
    """Convert a ``text<TAB>label`` file into sharded TFRecord files.

    Every non-blank line is split on tabs; the first field is the text and
    the second is the label name. The text is lowercased, stripped and
    encoded with ``get_class_ids``; the label is looked up in ``vocab``.
    Records are written to ``outputpath + "xlnetreading.tfrecords-NNN"``,
    starting a new file after every 5000 records.

    Args:
        inputpath: path of the tab-separated input file.
        vocab: mapping from label name to integer label id.
        outputpath: prefix prepended verbatim to each shard file name (no
            path separator is inserted).
        max_seq_length: sequence length forwarded to ``get_class_ids``.
        tokenize_fn: tokenizer forwarded to ``get_class_ids``.

    Raises:
        IndexError: if a non-blank line has no tab-separated second field.
    """
    records_per_shard = 5000
    shard_name = "xlnetreading.tfrecords-%.3d"

    def int64_feature(values):
        # Wrap a list of ints as an int64 tf.train.Feature.
        return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

    shard_index = 0
    records_in_shard = 0
    writer = tf.python_io.TFRecordWriter(outputpath + (shard_name % shard_index))
    try:
        with open(inputpath) as f:
            for line in f:
                if not line.strip():
                    # Blank lines carry no example; skip them instead of
                    # failing on the missing label field.
                    continue
                fields = line.split("\t")
                content = fields[0].lower().strip()
                label = vocab.get(fields[1].strip())
                input_ids, input_mask, segment_ids = get_class_ids(
                    content, max_seq_length, tokenize_fn)

                if records_in_shard == records_per_shard:
                    # Close the full shard before opening the next one so
                    # its contents are flushed and the handle is released.
                    writer.close()
                    writer = None
                    shard_index += 1
                    records_in_shard = 0
                    writer = tf.python_io.TFRecordWriter(
                        outputpath + (shard_name % shard_index))

                example = tf.train.Example(
                    features=tf.train.Features(
                        feature={'input_ids': int64_feature(input_ids),
                                 'input_mask': int64_feature(input_mask),
                                 'segment_ids': int64_feature(segment_ids),
                                 'label': int64_feature([label]),
                                 }))
                writer.write(example.SerializeToString())
                records_in_shard += 1
    finally:
        # Always close the current shard, including when an error interrupts
        # the conversion.
        if writer is not None:
            writer.close()

自己寫了一個文本分類的類,看下:

class XlnetReadingClass(object):
    """XLNet-based text classifier graph: model, loss, train op and accuracy.

    Attributes set by the constructor:
        xlnet_config, run_config: configuration objects built from the
            constructor arguments.
        input_ids, segment_ids, input_mask: the inputs after transposing
            their two axes.
        model: the ``xlnet.XLNetModel`` instance.
        per_example_loss, logits: outputs of ``modeling.classification_loss``.
        total_loss: mean of ``per_example_loss``.
        train_op: training op from ``model_utils.get_train_op``.
        acc: mean accuracy tensor for the current batch.
    """

    def __init__(self,model_config_path,is_training,FLAGS,input_ids,segment_ids,
                 input_mask,label,n_class):
        """Build the classification graph.

        Args:
            model_config_path: JSON path passed to ``xlnet.XLNetConfig``.
            is_training: forwarded to ``xlnet.create_run_config``.
            FLAGS: flag object; ``cls_scope``, ``summary_type`` and
                ``use_summ_proj`` are read here, and the whole object is
                passed to ``create_run_config`` and ``get_train_op``.
            input_ids, segment_ids, input_mask: rank-2 tensors, transposed
                with perm ``[1, 0]`` before being given to the model
                (presumably batch-major on input -- confirm against the
                input pipeline).
            label: integer class labels.
            n_class: number of classes.
        """
        self.xlnet_config = xlnet.XLNetConfig(json_path=model_config_path)
        self.run_config = xlnet.create_run_config(is_training, True, FLAGS)

        # The model is fed the transposed tensors.
        self.input_ids = tf.transpose(input_ids, [1, 0])
        self.segment_ids = tf.transpose(segment_ids, [1, 0])
        self.input_mask = tf.transpose(input_mask, [1, 0])

        self.model = xlnet.XLNetModel(
            xlnet_config=self.xlnet_config,
            run_config=self.run_config,
            input_ids=self.input_ids,
            seg_ids=self.segment_ids,
            input_mask=self.input_mask)

        cls_scope = FLAGS.cls_scope
        summary = self.model.get_pooled_out(FLAGS.summary_type, FLAGS.use_summ_proj)
        self.per_example_loss, self.logits = modeling.classification_loss(
            hidden=summary,
            labels=label,
            n_class=n_class,
            initializer=self.model.get_initializer(),
            scope=cls_scope,
            return_logits=True)

        self.total_loss = tf.reduce_mean(self.per_example_loss)

        with tf.name_scope("train_op"):
            self.train_op, _, _ = model_utils.get_train_op(FLAGS, self.total_loss)

        with tf.name_scope("acc"):
            one_hot_target = tf.one_hot(label, n_class)
            self.acc = self.accuracy(self.logits, one_hot_target)

    def accuracy(self,logits, labels):
        """Return the fraction of examples whose arg-max class is correct.

        Args:
            logits: class scores with the class axis last.
            labels: one-hot targets with the class axis last.

        Returns:
            A scalar float32 tensor with the mean accuracy.
        """
        # Softmax preserves the ordering of the scores, so the arg-max of
        # the raw logits is the predicted class.
        predicted = tf.argmax(logits, axis=-1)
        # Reduce over the class axis and flatten rather than tf.squeeze-ing
        # first: squeeze removes every size-1 axis, so it would also drop
        # the batch axis when the batch holds a single example.
        expected = tf.reshape(tf.argmax(labels, axis=-1), [-1])
        correct = tf.cast(tf.equal(predicted, expected), tf.float32)
        return tf.reduce_mean(correct)


def main(_):
    """Build the classifier graph and run the training loop.

    Reads TFRecord shards matching ``FLAGS.data_dir + "xlnetreading.tfrecords*"``,
    optionally initializes variables from ``FLAGS.init_checkpoint``, runs up
    to 8000 training steps, logs every 100 steps and saves a checkpoint to
    ``../exp/reading/model.ckpt`` every 500 steps.

    Args:
        _: unused positional argument.
    """
    # Imported locally so the timing/logging code below does not depend on
    # module-level imports of these standard-library modules.
    import datetime
    import time

    print('Loading config...')

    n_class = 38

    input_path = FLAGS.data_dir + "xlnetreading.tfrecords*"

    print("input_path:", input_path)
    files = tf.train.match_filenames_once(input_path)

    # `inputs` is not defined in this section of the file; the original
    # author's note describes it as the input path of your data. Presumably
    # it is the user-supplied input pipeline that returns batched tensors.
    input_ids, input_mask, segment_ids, label_ids = inputs(files, batch_size=FLAGS.batch_size, num_epochs=5,max_seq_length=FLAGS.max_seq_length)
    model_config_path = FLAGS.model_config_path
    # NOTE(review): is_training=False is forwarded to
    # xlnet.create_run_config although the loop below runs train_op --
    # confirm this is intended.
    is_training = False
    init_checkpoint = FLAGS.init_checkpoint

    model = XlnetReadingClass(model_config_path, is_training, FLAGS, input_ids,
                              segment_ids, input_mask, label_ids, n_class)

    tvars = tf.trainable_variables()

    # Bound up front so the variable listing below also works when no
    # checkpoint is configured.
    initialized_variable_names = {}
    if init_checkpoint:
        (assignment_map, initialized_variable_names) = (
            model_utils.get_assignment_map_from_checkpoint(tvars, init_checkpoint))
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        print("restore sucess  on cpu or gpu")

    session = tf.Session()
    try:
        session.run(tf.global_variables_initializer())
        session.run(tf.local_variables_initializer())

        print("**** Trainable Variables ****")
        for var in tvars:
            # Only variables restored from the checkpoint are listed.
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
                print("name ={0}, shape = {1}{2}".format(var.name, var.shape,
                                                         init_string))

        print("xlnet reading class  model will start train .........")

        print(session.run(files))
        saver = tf.train.Saver()
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord, sess=session)
        msg = 'Iter: {0:>6}, Train Loss: {1:>6.2},' \
              + '  Cost: {2}  Time:{3}  acc:{4}'
        try:
            start_time = time.time()
            for i in range(8000):
                _, loss_train, acc = session.run([model.train_op, model.total_loss, model.acc])
                if i % 100 == 0:
                    end_time = time.time()
                    time_dif = datetime.timedelta(seconds=int(round(end_time - start_time)))
                    print(msg.format(i, loss_train, time_dif, datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), acc))
                    start_time = time.time()
                if i % 500 == 0 and i > 0:
                    saver.save(session, "../exp/reading/model.ckpt", global_step=i)
        except tf.errors.OutOfRangeError:
            # The input pipeline is limited to a fixed number of epochs and
            # can run dry before the step budget is used up.
            print("input data exhausted, stopping training")
        finally:
            # Stop and join the queue-runner threads on errors too.
            coord.request_stop()
            coord.join(threads)
    finally:
        session.close()

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章