TFRecord 的創建與讀取

目錄

1. 概述

2. 創建 TFRecord

2.1 準備工作

2.2 圖片數據集轉換爲 TFRecord 格式

2.3 圖片地址集轉換爲 TFRecord 格式

2.4 小結

3. 讀取 TFRecord 文件


1. 概述

爲了訓練一個深度神經網絡,通常會在本地保存一些數據,這些數據分別被保存在 train,validation,test 文件夾中。這些文件夾散列地保存了成千上萬的圖片或文本文件,無論是直接讀取圖片還是通過 csv 文件保存圖片路徑的間接讀取方式都會帶來兩個問題:

  • 讀取速度慢:讀取圖片內容時需要逐個讀取本地磁盤中的散列文件

  • 內存空間:許多大型數據(如ImageNet)無法一次性加載到內存中

TensorFlow 提供一種統一的格式來存儲數據,即 TFRecord,允許將任意的數據轉換爲TensorFlow所支持的格式。TFRecord 內部使用了“Protocol Buffer”二進制數據編碼方案,只佔用一個內存塊,只需要一次性加載一個二進制文件的方式即可,簡單,快速,尤其對大型訓練數據很友好。因此,在進行深度模型訓練時,建議將訓練樣本存儲爲 TFRecord 格式,當訓練數據量比較大的時候,可以將數據分成多個TFRecord文件,提高處理效率。

TFRecord 文件中的數據通過 tf.train.Example Protocol Buffer 的格式存儲,tf.train.Example 的定義如下:

message Example {
    Features features = 1;
};
message Features {
    map<string, Feature> feature = 1;
};
message feature {
    oneof kind {
        BytesList bytes_list = 1;
        FloatList float_list = 2;
        Int64List int64_list = 3;
    }
};

從定義看,tf.train.Example 的數據結構較爲簡潔,包含了一組 Features(一個從屬性名到取值的字典) ,其中屬性名是 string 類型,屬性的取值 feature 是可以是 BytesList (字符串)、FloatList (實數列表) 或者 Int64List (整數列表)。

2. 創建 TFRecord

爲本地訓練數據集創建 TFRecord 格式的訓練數據集有兩種方式:

  • 對於小規模數據集,將數據數據集集整體轉換成 TFRecord 格式

  • 對於大規模數據集,將圖片數據集對應的地址集轉換爲 TFRecord 格式。

本節以貓狗數據集爲例,部分訓練樣本如下圖,使用以上兩種方法分別創建 tfrecords 文件。

2.1 準備工作

將貓狗數據集中所有的訓練樣本的路徑及其對應的標籤寫入到本地文件中。

def get_data(data_dir):
    '''
    :description: 獲取數據集中所有訓練樣本的路徑和標籤
    :param data_dir: 數據集存放目錄
    :return: 返回訓練樣本的路徑列表和標籤列表
    '''
    images_path = []
    images_label = []

    for root, _, names in os.walk(data_dir):
        for name in names:
            images_path.append(os.path.join(root, name))
            if name.split('.')[0] == 'cat':
                images_label.append(0)
            if name.split('.')[0] == 'dog':
                images_label.append(1)
    return images_path, images_label


def save_data_path(data_dir, save_path):
    '''
    :description: 保存訓練樣本的路徑和對應標籤到指定文件中
    :param data_dir: 訓練樣本所在目錄
    :param save_path: 指定文件保存路徑
    :return: 無返回值
    '''
    if os.path.exists(save_path):
        os.system("rm {}".format(save_path))
    os.system("touch {}".format(save_path))

    images_path, images_label = get_data(data_dir)
    temp = np.array([images_path, images_label])
    temp = temp.transpose()
    np.random.shuffle(temp)
    images = list(temp[:, 0])
    labels = list(temp[:, 1])
    with open(save_path, 'w') as f:
        for i in range(len(images)):
            content = '{} {} \n'.format(images[i], labels[i])
            f.write(content)
            
if __name__ == 'main':
    train_dir = '../data/images/train'
    save_path = '../data/images/train.txt'
    save_data_path(data_dir, save_path)

2.2 圖片數據集轉換爲 TFRecord 格式

整體轉換爲本地訓練數據集創建 TFRecord 格式的訓練數據集可以按照以下四個步驟進行:

  1. 準備數據並使用 tf.io.TFRecordWriter 創建一個 TFRecordWriter 對象

  2. 確定本地數據集的 Features,並使用以下三個函數:tf.train.Example()tf.train.Features()tf.train.Feature(),將單個樣例的 Features 轉換成一個 record

  3. 使用 writer.write(record) 方法將單個樣例的 record 寫入 TFRecord 文件

  4. 等到所有樣例都寫入 TFRecord 文件後,關閉 TFRecordWriter 對象

import os

from tqdm import tqdm
from PIL import Image


def read_dataset(filepath):
    '''
    :description: 從本地文件中讀取訓練集中所有樣本的路徑及其對應標籤
    :param filepath: 本地文件路徑
    :return: 無返回值
    '''
    with open(filepath, 'r') as f:
        trainset = []
        lines = f.readlines()
        for line in lines:
            line = line.strip('\n')
            image_path = line.split()[0]
            label = int(line.split()[1])
            trainset.append([image_path, label])
        return trainset
    

def encode_to_tfrecord(filepath, record_name, save_dir):
    '''
    :description: 將數據集整體轉換爲 tfrecords 格式
    :param filepath: 保存着所有樣本路徑的文件
    :param record_name: 待保存的 tfrecord 文件的文件名
    :param save_dir: tfrecord 文件保存的目錄
    :return: None
    '''
    record_path = os.path.join(save_dir, record_name + '.tfrecords')
    if os.path.exists(record_path):
        os.remove(record_path)

    trainset = read_dataset(filepath)
    writer = tf.python_io.TFRecordWriter(record_path)
    pbar = tqdm(trainset)
    for train_data in pbar:
        try:
            image = Image.open(train_data[0])
            image_raw = image.tobytes()
            label = train_data[1]
            example = tf.train.Example(
                features=tf.train.Features(feature={
                    'image_raw': tf.train.Feature(
                        bytes_list=tf.train.BytesList(
                            value=[image_raw])),
                    'label': tf.train.Feature(
                        int64_list=tf.train.Int64List(
                            value=[label]))
                })
            )
            writer.write(example.SerializeToString())

        except IOError:
            print('could not read:', train_data[0])
        pbar.set_description('transforming: {}'.format(train_data[0].split('/')[-1]))
    writer.close()

    
if __name__ == 'main':
    encode_to_tfrecord(filepath='../data/images/train.txt', 
                       record_name='cat_dog', 
                       save_dir='../data/tfrecords')

程序運行結果:經過大約11分31秒的時間,25000 幅圖像被轉換轉成了 TFRecord 格式的訓練集,轉換得到的 TFRecord 文件大小爲 11.34 GB。

2.3 圖片地址集轉換爲 TFRecord 格式

地址集轉換爲 TFRecord 的步驟和整體轉換幾乎一致,只不過本地數據集的 Features 有點不一樣。在整體轉換中,單個樣例的 Features 分別爲圖像數據和標籤,而本次轉換中,單個樣例的 Features 則是圖像的本地地址和標籤,轉換過程如下:

import os

from PIL import Image
from tqdm import tqdm


def read_dataset(filepath):
    '''
    :description: 從本地文件中讀取訓練集中所有樣本的路徑及其對應標籤
    :param filepath: 本地文件路徑
    :return: 無返回值
    '''
    with open(filepath, 'r') as f:
        trainset = []
        lines = f.readlines()
        for line in lines:
            line = line.strip('\n')
            image_path = line.split()[0]
            label = int(line.split()[1])
            trainset.append([image_path, label])
        return trainset
    

def convert_to_tfrecord(filepath, record_name, 
                        save_dir='./data/tfrecords', encoding='utf-8'):
    '''
    :description: 使用 gfile 方式將數據集轉換爲 tfrecords 格式
    :param filepath: 保存着數據集路徑的文件
    :param record_name: 將要保存的 tfrecord 文件的文件名
    :param save_dir: tfrecord 文件保存的目錄
    :param encoding: 字符串轉換爲字節類型的編碼方式
    :return: None
    '''
    record_path = os.path.join(save_dir, record_name + '.tfrecords')
    if os.path.exists(record_path):
        os.remove(record_path)
    writer = tf.python_io.TFRecordWriter(record_path)
    trainset = utils.read_dataset(filepath)

    pbar = tqdm(trainset)
    for train_data in pbar:
        pbar.set_description('transforming {}'.format(train_data[0].split('/')[-1]))
        image_path = bytes(train_data[0], encoding=encoding)
        image_label = train_data[1]
        example = tf.train.Example(
            features=tf.train.Features(feature={
                'image_path': tf.train.Feature(
                    bytes_list=tf.train.BytesList(
                        value=[image_path])),
                'image_label': tf.train.Feature(
                    int64_list=tf.train.Int64List(
                        value=[image_label]))
            }))
        writer.write(example.SerializeToString())
    writer.close()
    
    
if __name__ == 'main':
    convert_to_tfrecord(filepath='../data/images/train.txt',
                        record_name='cat_dog', 
                        save_dir='../data/tfrecords')

程序運行結果:經過大約 4 秒的時間, 25000 幅圖像的地址集轉換成了 TFRecord 格式,轉換後得到的 TFRecord 文件大小爲 2.9 MB。

2.4 小結

對比以上兩種方式可以發現:

  1. 轉換得到的 TFRecord 格式的數據集大小將遠遠大於原始數據集的大小;

  2. 整體數據集轉換爲 TFRecord 格式耗費時間較長;

對於大規模數據集而言,將其轉換爲 TFRecord 格式是一個非常浩大的工程,而且往往由於轉換後的TFRecord 格式的數據集容量過於龐大,後續的加載和讀取將耗費更多的資源,從而引起一系列問題。

因此,工程上,通常選擇將大規模數據集的地址集轉換爲 TFRecord 格式,每次直接讀取生成 batch 後的地址,並通過這些地址,找到訓練樣本。

3. 讀取 TFRecord 文件

TFRecord 格式的數據集生成完成後即可用於神經網絡的訓練。在訓練時,通常需要從 TFRecord 格式的數據集中獲取訓練樣本,從 TFRecord 格式的數據集中獲取訓練樣本按照以下五個步驟進行:

  1. 獲取文件列表,並使用 tf.train.string_input_producer() 方法創建文件名隊列 filename_queue。(說明:當數據集較大時,通常創建多個 TFRecord 文件來保存數據集)

  2. 使用 tf.TFRecordReader() 創建一個 TFRecordReader 對象。

  3. 使用 reader.read() 方法從文件名隊列中讀取樣本。

  4. 使用 tf.io.parse_single_example() 方法解析出單個樣本,並使用 tf.reshape() 對樣本進行 reshape。

  5. 使用 tf.train.batch() 或者 tf.train.shuffle_batch() 方法將隊列中解析到的樣本組合成一個批次

def decode_from_tfrecords(filename_list, batch_size):
    filename_queue = tf.train.string_input_producer(filename_list)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.io.parse_single_example(serialized_example, features={
        'image_raw': tf.FixedLenFeature([], dtype=tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    })
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    image = tf.reshape(image, [227, 227, 3])
    label = tf.cast(features['label'], tf.int32)

    min_after_dequeue = 1000
    capacity = min_after_dequeue * 3 + batch_size
    image_batch, label_batch = tf.train.shuffle_batch([image, label], batch_size,
                                                      min_after_dequeue=min_after_dequeue,
                                                      capacity=capacity,
                                                      num_threads=3)
    return image_batch, label_batch

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章