TFRecords文件實現不定長圖片和標籤的存儲和讀取感悟(1)(附完整代碼)

最近一段時間接觸到用tfrecord儲存數據和讀取,期間踩了數之不盡的坑,在消bug的路上艱難行走,所以在這裏記錄下我所遇見過的各種坑,望共勉。 

TFRecord是谷歌推薦的一種二進制文件格式,理論上它可以保存任何格式的信息。使用tfrecord時,實際上是先讀取原生數據,然後轉換成tfrecord格式,在存儲在硬盤上。以後使用數據時,就可以從tfrecord文件 解碼讀出。

TFRecords文件中包含了類型爲tf.train.Example的協議內存塊(protocol buffer),而在協議內存塊中又包含了字段features(tf.train.Features)。features中又包含了若干個feature,每一個feature是一個map,也就是key-value的鍵值對, key取值是String類型,而value是Feature類型的消息體,它包含三種,BytesList,FloatList和Int64List,它們都是列表list的形式。如下面的函數int64_feature和int64_list_feature,兩者最大的區別在於前者是value=[value]和後者的value=value, []表示列表。

def int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def int64_list_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
example = tf.train.Example(features=tf.train.Features(feature={
    'label': int64_list_feature(image_label),
    'image': bytes_feature(image),
    'h': int64_feature(shape[0]),
    'w': int64_feature(shape[1]),
    'c': int64_feature(shape[2])
}))

如上所定義的Example消息體,包含一張圖片image的信息,及其標籤label信息和shape大小信息(height, width, channel),在這裏和大多數博客裏不一樣的在於‘label’標籤,通常的數據標籤是一個整數,例如貓狗圖片,用‘0’表示貓,用‘1’表示狗,即使是多分類標籤,也可以用0-N來表示各個類別,而我們的圖像標籤是一串中文或者英文,長度不一,首先在字典中查找其對應下標,形成list數組。關於這個標籤的處理,所以在這裏提供了兩個解決方案(都是踩過的坑):

  • 方案一是將標籤list數組轉換成one-hot形式,不使用tensorflow的tf.one-hot表示,而是自己定義函數,最後使得每個類別標籤爲一個字典大小的向量,讀取時, feature中定義'label': tf.FixedLenFeature([VOCUBLARY_SIZE], tf.int64),如果不加大小VOCUBLARY_SIZE,會報錯
  • 方案二是針對tf.nn.ctc_loss中labels參數的SparseTensor稀疏張量的要求,上述方案一得到的雖然是一個類one-hot形式,終究不是稀疏張量,所以將讀取到的label直接傳給參數 時還是labels時,是要報錯,所以爲了得到稀疏向量,直接將標籤存儲,讀取時,feature中定義'label': tf.VarLenFeature(dtype=tf.int64),使用的是變長讀取,這樣得到的是一個稀疏張量SparseTensor

另一個特殊之處在於圖片shape信息的儲存,可以看到這裏不是直接存儲shape,而是分開存儲,因爲每一個圖片的尺寸大小不同,所以如果直接以shape的大小存儲,也同樣會報錯。所以在這裏定義了三個鍵值對。讀取image時,就可以使用讀取的h,w,c這三個數據reshape圖像,如果圖像是定長的,shape的大小就可以直接定義,例如shape=[224, 224, 3]等等。

h = tf.cast(image_features['h'], tf.int32)
w = tf.cast(image_features['w'], tf.int32)
c = tf.cast(image_features['c'], tf.int32)

image = tf.decode_raw(image_features['image'], tf.uint8)
image = tf.cast(image, tf.float32)
image = tf.reshape(image, shape=[h, w, c])

最後說到圖像image方面,現在面對是image尺寸不一,目的是要圖片的高要resize到同一大小,寬度不定長。因爲要讀取數據時,數據量巨大,程序每次運行時是需要分batch,而每一batch裏面要求大小一致,所以如果對image不處理,也是會報錯。

  •  第一種情況是我碰見的image數據,height大小一致,寬度不定長,這樣存儲時,不用對數據進行resize,只是數據讀取時,對每一個圖像reshape後,使用resized_image = tf.image.resize_image_with_crop_or_pad(image, target_height=32, target_width=max_width)對圖像進行填充,雖說有剪裁,但是剪裁後會影響結果,所以這裏max_width的設定要儘可能包含所有的圖片的寬度,這樣後面在對圖像進行reshape後resized_image = tf.reshape(resized_image, shape=[32, max_width, 3])就可以batch數據了
  • 第二種情況是碰見的image數據長寬各不一,所以在存儲前就需要對image進行resize,等比例縮小。後面處理和第一種情況類似

關於存儲成tfrecord的步驟和讀取tfrecord文件步驟,現在已經有很多博客進行詳細描述,我就不用過多贅述,下面粘貼的是完整的代碼。這個代碼得到的tfrecord文件要比原圖像文件大10倍左右,原圖像有2.2G左右,生成的tfrecord文件大約有22G,網上也有人給出答案,例如一個圖像共有h*w*c個像素,將圖片轉變成byte類型時,這些像素按順序存在一個二進制序列中,每個像素需要用一個的字符進行表示,所以這樣就使得圖像文件存儲變大,如何解決,tensorflow中提供了一個tf.gfile.FastGFile類,可以直接讀取圖像的bytes形式

tf.gfile.FastGFile(filename, 'rb').read()

'r'表示要從文件中讀取數據,‘b’表示要讀取二進制數據,但是由於我們數據的複雜性,所以就不再嘗試此種方法了。

import tensorflow as tf
from PIL import Image
import numpy as np
import os
import random
from config import CHAR_VECTOR
from config import NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
from config import NUM_EXAMPLES_PER_EPOCH_FOR_TEST

VOCUBLARY_SIZE = len(CHAR_VECTOR)


def resize_image(image):
    '''resize the size of image'''
    width, height = image.size
    ratio = 32.0 / float(height)
    image = image.resize((int(width * ratio), 32))
    return image


def generation_vocublary(CHAR_VECTOR):
    vocublary = {}
    index = 0
    for char in CHAR_VECTOR:
        vocublary[char] = index
        index = index + 1
    return vocublary


def int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def int64_list_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def generation_TFRecord(data_dir):
    vocublary = generation_vocublary(CHAR_VECTOR)

    image_name_list = []
    for file in os.listdir(data_dir):
        if file.endswith('.jpg'):
            image_name_list.append(file)

    random.shuffle(image_name_list)
    capacity = len(image_name_list)

    # 生成train tfrecord文件
    train_writer = tf.python_io.TFRecordWriter('./dataset/train_dataset.tfrecords')
    train_image_name_list = image_name_list[0:int(capacity * 0.9)]
    for train_name in train_image_name_list:
        train_image_label = []
        for s in train_name.strip('.jpg'):
            train_image_label.append(vocublary[s])

        train_image = Image.open(os.path.join(data_dir, train_name))
        train_image = resize_image(train_image)
        # print(image.size)
        train_image_array = np.asarray(train_image, np.uint8)
        train_shape = np.array(train_image_array.shape, np.int32)
        train_image = train_image.tobytes()

        train_example = tf.train.Example(features=tf.train.Features(feature={
            'label': int64_list_feature(train_image_label),
            'image': bytes_feature(train_image),
            'h': int64_feature(train_shape[0]),
            'w': int64_feature(train_shape[1]),
            'c': int64_feature(train_shape[2])
        }))
        train_writer.write(train_example.SerializeToString())
    train_writer.close()

    # 生成test tfrecord文件
    test_writer = tf.python_io.TFRecordWriter('./dataset/test_dataset.tfrecords')
    test_image_name_list = image_name_list[int(capacity * 0.9):capacity]
    for test_name in test_image_name_list:
        test_image_label = []
        for s in test_name.strip('.jpg'):
            test_image_label.append(vocublary[s])

        test_image = Image.open(os.path.join(data_dir, test_name))
        test_image = resize_image(test_image)
        # print(image.size)
        test_image_array = np.asarray(test_image, np.uint8)
        test_shape = np.array(test_image_array.shape, np.int32)
        test_image = test_image.tobytes()

        test_example = tf.train.Example(features=tf.train.Features(feature={
            'label': int64_list_feature(test_image_label),
            'image': bytes_feature(test_image),
            'h': int64_feature(test_shape[0]),
            'w': int64_feature(test_shape[1]),
            'c': int64_feature(test_shape[2])
        }))
        test_writer.write(test_example.SerializeToString())
    test_writer.close()


def read_tfrecord(filename, max_width, batch_size, train=True):
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialize_example = reader.read(filename_queue)
    image_features = tf.parse_single_example(serialized=serialize_example,
                                             features={
                                                 # 'label': tf.FixedLenFeature([VOCUBLARY_SIZE], tf.int64),
                                                 'label': tf.VarLenFeature(dtype=tf.int64),
                                                 'image': tf.FixedLenFeature([], tf.string),
                                                 'h': tf.FixedLenFeature([], tf.int64),
                                                 'w': tf.FixedLenFeature([], tf.int64),
                                                 'c': tf.FixedLenFeature([], tf.int64)
                                             })
    h = tf.cast(image_features['h'], tf.int32)
    w = tf.cast(image_features['w'], tf.int32)
    c = tf.cast(image_features['c'], tf.int32)

    image = tf.decode_raw(image_features['image'], tf.uint8)
    image = tf.cast(image, tf.float32)
    image = tf.reshape(image, shape=[h, w, c])
    resized_image = tf.image.resize_image_with_crop_or_pad(image, target_height=32, target_width=max_width)
    resized_image = tf.reshape(resized_image, shape=[32, max_width, 3])

    label = tf.cast(image_features['label'], tf.int32)

    min_fraction_of_example_in_queue = 0.4
    if train is True:
        min_queue_examples = int(min_fraction_of_example_in_queue * NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN)
        train_image_batch, train_label_batch = tf.train.shuffle_batch([resized_image, label],
                                                                      batch_size=batch_size,
                                                                      capacity=min_queue_examples + 3 * batch_size,
                                                                      min_after_dequeue=min_queue_examples,
                                                                      num_threads=32)
        return train_image_batch, train_label_batch
    else:
        min_queue_examples = int(min_fraction_of_example_in_queue * NUM_EXAMPLES_PER_EPOCH_FOR_TEST)
        test_image_batch, test_label_batch = tf.train.batch([resized_image, label],
                                                            batch_size=batch_size,
                                                            capacity=min_queue_examples + 3 * batch_size,
                                                            num_threads=32)
        return test_image_batch, test_label_batch


def index_to_word(result):
    return ''.join([CHAR_VECTOR[i] for i in result])


def main(argv):
    generation_TFRecord('./dataset/images')
    train_image, train_label = read_tfrecord('./dataset/train_dataset.tfrecords', 250, 32)
    test_image, test_label = read_tfrecord('./dataset/test_dataset.tfrecords', 250, 32)
    with tf.Session() as session:
        session.run(tf.group(tf.global_variables_initializer(),
                             tf.local_variables_initializer()))
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        image_train, label_train = session.run([train_image, train_label])
        print(image_train.shape)

        image_test, label_test = session.run([test_image, test_label])
        print(image_test.shape)

        for image, label in  zip(image_test, label_test):
            # 將array轉換成image
            img = Image.fromarray(image, 'RGB')
            img.save(index_to_word(label) + '.jpg')
            print(index_to_word(label))

        coord.request_stop()
        coord.join(threads=threads)


if __name__ == '__main__':
    tf.app.run()

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章