主要介紹TensorFlow 另一個數據處理的利器——TFRecord。
一、什麼是TFRecord ?
TFRecord 是 TensorFlow 中的數據集存儲格式。當我們將數據集整理成 TFRecord 格式後,TensorFlow 就可以高效地讀取和處理這些數據集,從而幫助我們更高效地進行大規模的模型訓練。
TFRecord 可以理解爲一系列序列化的 tf.train.Example 元素所組成的列表文件,而每一個 tf.train.Example 又由若干個 tf.train.Feature 的字典組成。形式如下:
# dataset.tfrecords
[
{ # example 1 (tf.train.Example)
'feature_1': tf.train.Feature,
...
'feature_k': tf.train.Feature
},
...
{ # example N (tf.train.Example)
'feature_1': tf.train.Feature,
...
'feature_k': tf.train.Feature
}
]
爲了將形式各樣的數據集整理爲 TFRecord 格式可按以下步驟:
1、讀取該數據元素到內存;
2、將該元素轉換爲 tf.train.Example 對象(每一個 tf.train.Example 由若干個 tf.train.Feature 的字典組成,因此需要先建立 Feature 的字典);
3、將該 tf.train.Example 對象序列化爲字符串,並通過一個預先定義的 tf.io.TFRecordWriter 寫入 TFRecord 文件。
而讀取 TFRecord 數據則可按照以下步驟:
1、通過 tf.data.TFRecordDataset 讀入原始的 TFRecord 文件(此時文件中的 tf.train.Example 對象尚未被反序列化),獲得一個 tf.data.Dataset 數據集對象;
2、通過 Dataset.map 方法,對該數據集對象中的每一個序列化的 tf.train.Example 字符串執行 tf.io.parse_single_example 函數,從而實現反序列化。
二、實戰例子理解,貓狗數據集轉換爲tfrecord文件,並讀取
貓狗數據集下載地址
# 將數據集存儲爲 TFRecord 文件
import tensorflow as tf
import os
data_dir = 'D:/深度學習/tensorflow2.0/catsAndDogs/datasets/'
train_cats_dir = data_dir + '/train/cats/'
train_dogs_dir = data_dir + '/train/dogs/'
tfrecord_file = data_dir + '/train/train.tfrecords'
train_cat_filenames = [train_cats_dir + filename for filename in os.listdir(train_cats_dir)]
train_dog_filenames = [train_dogs_dir + filename for filename in os.listdir(train_dogs_dir)]
train_filenames = train_cat_filenames + train_dog_filenames
train_labels = [0] * len(train_cat_filenames) + [1] * len(train_dog_filenames) # 將 cat 類的標籤設爲0,dog 類的標籤設爲1
# 迭代讀取每張圖片,建立 tf.train.Feature 字典和 tf.train.Example 對象,序列化並寫入 TFRecord 文件。
with tf.io.TFRecordWriter(tfrecord_file) as writer:
for filename, label in zip(train_filenames, train_labels):
image = open(filename, 'rb').read() # 讀取數據集圖片到內存,image 爲一個 Byte 類型的字符串
feature = { # 建立 tf.train.Feature 字典
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])), # 圖片是一個 Bytes 對象
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])) # 標籤是一個 Int 對象
}
example = tf.train.Example(features=tf.train.Features(feature=feature)) # 通過字典建立 Example
writer.write(example.SerializeToString()) # 將Example序列化並寫入 TFRecord 文件
# 運行以上代碼,不出片刻,我們即可在 tfrecord_file 所指向的文件地址獲得一個 500MB 左右的 train.tfrecords 文件。
# 讀取 TFRecord 文件
raw_dataset = tf.data.TFRecordDataset(tfrecord_file) # 讀取 TFRecord 文件
feature_description = { # 定義Feature結構,告訴解碼器每個Feature的類型是什麼
'image': tf.io.FixedLenFeature([], tf.string),
'label': tf.io.FixedLenFeature([], tf.int64),
}
def _parse_example(example_string): # 將 TFRecord 文件中的每一個序列化的 tf.train.Example 解碼
feature_dict = tf.io.parse_single_example(example_string, feature_description)
feature_dict['image'] = tf.io.decode_jpeg(feature_dict['image']) # 解碼JPEG圖片
return feature_dict['image'], feature_dict['label']
dataset = raw_dataset.map(_parse_example)
# 運行以上代碼後,我們獲得一個數據集對象 dataset ,這已經是一個可以用於訓練的 tf.data.Dataset 對象了!
import matplotlib.pyplot as plt
for image, label in dataset:
plt.title('cat' if label == 0 else 'dog')
plt.imshow(image.numpy())
plt.show()