XLNet -- Data Preprocessing

    for split, batch_size in zip(
            ["train", "valid"],
            [FLAGS.per_host_train_bsz, FLAGS.per_host_valid_bsz]):

        if batch_size <= 0: continue
        print("Converting {} set...".format(split))
        corpus.convert_to_tfrecords(split, save_dir, batch_size, FLAGS.tgt_len,
                                    FLAGS.num_core_per_host, FLAGS=FLAGS)

per_host_train_bsz: the batch size taken for each training step on a single host.

record_name: the file name under which the record info (metadata) is saved.
data: the input read from train.txt when the corpus was created, already converted into a vector of token IDs.
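
For reference, the FLAGS consumed here are ordinary command-line flags. A minimal sketch of the relevant definitions, assuming absl-style flags as in the Transformer-XL reference code (the defaults shown are illustrative, not the author's values):

    from absl import flags

    FLAGS = flags.FLAGS
    flags.DEFINE_integer("per_host_train_bsz", 32, "Train batch size per host.")
    flags.DEFINE_integer("per_host_valid_bsz", 32, "Valid batch size per host.")
    flags.DEFINE_integer("tgt_len", 70, "Number of tokens predicted per step.")
    flags.DEFINE_integer("num_core_per_host", 8, "Number of cores per host.")
    flags.DEFINE_bool("use_tpu", False, "Whether records are written for TPU.")
    flags.DEFINE_integer("num_passes", 1, "Passes over the training data.")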

    def convert_to_tfrecords(self, split, save_dir, bsz, tgt_len,
                             num_core_per_host, **kwargs):
        FLAGS = kwargs.get('FLAGS')

        file_names = []
        use_tpu = FLAGS.use_tpu and not (split == "test" and num_core_per_host == 1)

        if use_tpu:
            record_name = "record_info-{}.bsz-{}.tlen-{}.core-{}.json".format(
                split, bsz, tgt_len, num_core_per_host)
        else:
            record_name = "record_info-{}.bsz-{}.tlen-{}.json".format(
                split, bsz, tgt_len)

        record_info_path = os.path.join(save_dir, record_name)

        if self.dataset in ["ptb", "wt2", "wt103", "enwik8", "tangshi", "doupo", "test", "zhihu", "poetry","zhuxian","xuezhongqijie","longzu"]:
            data = getattr(self, split)

            bin_sizes = get_bin_sizes(
                data, bsz // num_core_per_host, tgt_len, self.cutoffs)
            file_name, num_batch = create_ordered_tfrecords(
                save_dir, split, data, bsz, tgt_len, num_core_per_host,
                self.cutoffs, bin_sizes,
                num_passes=FLAGS.num_passes if split == 'train' and use_tpu else 1,
                use_tpu=use_tpu)
            file_names.append(file_name)

        # note: bin_sizes, file_name and num_batch are only assigned inside the
        # branch above, so this assumes self.dataset is one of the listed corpora
        with open(record_info_path, "w") as fp:
            record_info = {
                "filenames": file_names,
                "bin_sizes": bin_sizes,
                "num_batch": num_batch
            }
            json.dump(record_info, fp)
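
Once the metadata file is written, it can be loaded back when building the input pipeline. A minimal sketch (the path and the values shown are illustrative):

    import json
    import os

    record_info_path = os.path.join("save_dir", "record_info-train.bsz-32.tlen-70.json")
    with open(record_info_path) as fp:
        record_info = json.load(fp)
    # e.g. {"filenames": ["train.bsz-32.tlen-70.tfrecords"],
    #       "bin_sizes": [], "num_batch": 1234}
    print(record_info["num_batch"])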

The data is processed row by row and column by column, over bsz rows in total:
inputs: elements [t, t + tgt_len) of the current row.
labels: the elements of the current row from t + 1 onward, i.e. the same sequence as inputs shifted one token to the right.
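
A tiny worked example of the slicing (toy token IDs; tgt_len is assumed to be 4):

    import numpy as np

    row = np.array([10, 11, 12, 13, 14, 15])   # one row of batched_data
    t, tgt_len = 0, 4
    inputs = row[t:t + tgt_len]                 # [10 11 12 13]
    labels = row[t + 1:t + tgt_len + 1]         # [11 12 13 14], shifted by one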

Each inputs/labels pair is stored in a tf.train.Example, serialized, and written with record_writer, which persists the preprocessed data.
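
The _int64_feature and _float_feature helpers used below are the usual TFRecord wrappers; a minimal sketch, matching the Transformer-XL reference code:

    import tensorflow as tf

    def _int64_feature(values):
        # wrap a list/array of ints as a TFRecord int64 feature
        return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

    def _float_feature(values):
        # wrap a list/array of floats as a TFRecord float feature
        return tf.train.Feature(float_list=tf.train.FloatList(value=values))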

def create_ordered_tfrecords(save_dir, basename, data, batch_size, tgt_len,
                             num_core_per_host, cutoffs=[], bin_sizes=[],
                             num_passes=1, use_tpu=False):
    # save_dir is the directory where the .tfrecords file is written
    if use_tpu:
        file_name = "{}.bsz-{}.tlen-{}.core-{}.tfrecords".format(
            basename, batch_size, tgt_len, num_core_per_host)
    else:
        file_name = "{}.bsz-{}.tlen-{}.tfrecords".format(
            basename, batch_size, tgt_len)

    save_path = os.path.join(save_dir, file_name)
    record_writer = tf.python_io.TFRecordWriter(save_path)

    batched_data = batchify(data, batch_size, num_passes)

    num_batch = 0
    for t in range(0, batched_data.shape[1] - 1, tgt_len):
        # current target length: if fewer than tgt_len tokens remain, take the remainder
        cur_tgt_len = min(batched_data.shape[1] - 1 - t, tgt_len)
        # drop the remainder if use tpu
        if use_tpu and cur_tgt_len < tgt_len:
            break
        if num_batch % 500 == 0:
            print("  processing batch {}".format(num_batch))
        for idx in range(batch_size):
            inputs = batched_data[idx, t:t + cur_tgt_len]
            labels = batched_data[idx, t + 1:t + cur_tgt_len + 1]

            # features dict
            feature = {
                "inputs": _int64_feature(inputs),
                "labels": _int64_feature(labels),
            }

            if len(cutoffs) > 0 and use_tpu:
                # validate `bin_sizes` and `cutoffs`
                assert len(cutoffs) - len(bin_sizes) == 2, \
                    "len(cutoffs) - len(bin_sizes) != 2"

                # mask for bin 0
                left, right = cutoffs[:2]
                inp_mask = ((inputs >= left) * (inputs < right)).astype(np.float32)
                tgt_mask = ((labels >= left) * (labels < right)).astype(np.float32)

                feature["inp_mask"] = _float_feature(inp_mask)
                feature["tgt_mask"] = _float_feature(tgt_mask)

                # refresh `inp_cnts` and `tgt_cnts` for each TPU core
                if idx % (batch_size // num_core_per_host) == 0:
                    inp_cnts = [0] * len(bin_sizes)
                    tgt_cnts = [0] * len(bin_sizes)

                # adaptive-softmax-style head labels: labels that fall into tail
                # bin b are mapped to the bin's placeholder id (cutoffs[1] + b)
                head_labels = np.copy(labels)
                inp_pos_per_bin, tgt_pos_per_bin = [], []
                for b, (left, right) in enumerate(zip(cutoffs[1:-1], cutoffs[2:])):
                    inp_pos = np.where((inputs >= left) * (inputs < right))[0]
                    tgt_pos = np.where((labels >= left) * (labels < right))[0]
                    inp_pos_per_bin.append(inp_pos)
                    tgt_pos_per_bin.append(tgt_pos)

                    head_labels[tgt_pos] = cutoffs[1] + b

                feature["head_labels"] = _int64_feature(head_labels)

                # permutation feature
                def _add_perm_feature(feature, pos_per_bin, cnts, prefix):
                    for b, pos in enumerate(pos_per_bin):
                        idx_tuple = []
                        for p in pos:
                            if cnts[b] < bin_sizes[b]:
                                idx_tuple.append([p, cnts[b]])
                                cnts[b] += 1
                            else:
                                break

                        n_tup = len(idx_tuple)
                        tup = np.array(idx_tuple).reshape(n_tup * 2)

                        feature["{}_cnt_{}".format(prefix, b)] = _int64_feature([n_tup])
                        feature["{}_tup_{}".format(prefix, b)] = _int64_feature(tup)

                _add_perm_feature(feature, inp_pos_per_bin, inp_cnts, "inp")
                _add_perm_feature(feature, tgt_pos_per_bin, tgt_cnts, "tgt")

            example = tf.train.Example(features=tf.train.Features(feature=feature))
            record_writer.write(example.SerializeToString())

        num_batch += 1

    record_writer.close()
    print("Done writing {}. batches: {}".format(file_name, num_batch))

    return file_name, num_batch
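
create_ordered_tfrecords also relies on batchify, which this excerpt does not show. A minimal sketch of the single-pass case, assuming the usual Transformer-XL behavior (the reference code additionally handles num_passes > 1 by concatenating shifted copies of the data):

    import numpy as np

    def batchify(data, batch_size, num_passes=1):
        # trim the 1-D token-ID stream so it divides evenly into batch_size
        # rows, then reshape it to [batch_size, num_step]
        num_step = len(data) // batch_size
        data = data[:batch_size * num_step]
        return np.asarray(data).reshape(batch_size, num_step)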