TensorFlow 2.1.0 使用 TFRecord 轉存與讀取圖片

前言

當 NLP 玩家遇到一個 CV 圖像分類的任務時,老實的說,我是有點懵逼的。。。

任務目標是,訓練一個層數較少,結構較爲簡單的圖像分類模型,使用其最後一層隱藏層輸出,作爲特徵向量來表徵圖像。

之前都是使用 Keras 較多,於是本次準備藉着這個簡單的任務上手 TensorFlow 2.1 。


數據加載

Python generator 出現的問題

TensorFlow 2.1 自帶的 tf.data.Dataset 處理訓練數據十分好用,並且自帶 shuffle,repeat,和劃分 batch 的方法。可以通過python generator, numpy list, Tensor slices 等數據結構直接構成 Dataset。

我訓練使用的數據是文檔中的插圖,5個類別共 10w 張。

起初我使用的方法是:構造一個 python generator,訓練時,使用 tf 自帶的 tf.io.read_file() 和 tf.image.decode_jpeg() 方法從磁盤中讀取數據,再使用 tf.data.Dataset.from_generator 生成數據集。

但訓練時發現這樣的數據處理有着很大的問題:受制於generator 的讀取數據速度,batch 數據生成的速度跟不上 GPU 的訓練速度,導致 GPU 的利用率只有不到 10%,訓練速度很慢。很慢。。很慢。。。

 

 TFRecord

這時我想到了 TFRecord 。TFRecord 可以將數據轉存爲二進制文件保存,這樣在訓練時讀取數據就不會遇到以上的問題了。

使用 TFRecord 來進行數據處理,首先需要將原始圖片數據轉存爲 TFRecord 格式:

def pares_image(pic):
'''
    圖片預處理,並轉爲字符串
'''
    label = pares_label(pic)
    try:
        img = Image.open(pic)
        img = img.convert('RGB')
        img = img.resize((54,96))
        img_raw = img.tobytes()
    except:
        return None, None
    return img_raw, label


train_data_list = get_file_path(train_data_path)


writer = tf.io.TFRecordWriter('./data/test_data')
for data in tqdm(test_data_list):
    img_raw, label = pares_image(data)
    if (img_raw is not None) and (label != 'not valid'):
        exam = tf.train.Example(
            features = tf.train.Features(
                feature = {
                    'label': tf.train.Feature(int64_list=tf.train.Int64List (value=[int(label_2_idx[label])])),
                    'data' : tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
                }
            )
        )
        writer.write (exam.SerializeToString())
writer.close()  

文件讀取

接下來讀取 TFRecord 文件,加載進 tf.data,Dataset 

train_reader = tf.data.TFRecordDataset('./data/train_data')
test_reader = tf.data.TFRecordDataset('./data/test_data')
valid_reader = tf.data.TFRecordDataset('./data/valid_data')

feature_description = {
    'data' : tf.io.FixedLenFeature([], tf.string),
    'label': tf.io.FixedLenFeature([] , tf.int64, default_value=-1)
}
def _parse_function (exam_proto):
    temp = tf.io.parse_single_example (exam_proto, feature_description)
    img = tf.io.decode_raw(temp['data'], tf.uint8)
    img = tf.reshape(img, [54, 96, 3])
    img = img / 255
    label = temp['label']
    return (img, label)


train_dataset = train_reader.repeat(5).shuffle(12800, reshuffle_each_iteration=True).map(_parse_function).batch(128)
test_dataset = test_reader.repeat(5).shuffle(12800, reshuffle_each_iteration=True).map(_parse_function).batch(128)
valid_dataset = test_reader.repeat(5).shuffle(12800, reshuffle_each_iteration=True).map(_parse_function).batch(128)

讀取文件時,需要重新將二進制的圖像數據重新 decode 爲 Tensor,並進行數值歸一化處理。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章