【深度學習 走進tensorflow2.0】TensorFlow 2.0 常用模塊TFRecord

主要介紹TensorFlow 另一個數據處理的利器——TFRecord。

一、什麼是TFRecord ?
TFRecord 是 TensorFlow 中的數據集存儲格式。當我們將數據集整理成 TFRecord 格式後,TensorFlow 就可以高效地讀取和處理這些數據集,從而幫助我們更高效地進行大規模的模型訓練。

TFRecord 可以理解爲一系列序列化的 tf.train.Example 元素所組成的列表文件,而每一個 tf.train.Example 又由若干個 tf.train.Feature 的字典組成。形式如下:

# dataset.tfrecords
 [
     {   # example 1 (tf.train.Example)
         'feature_1': tf.train.Feature,
         ...
         'feature_k': tf.train.Feature
     },
     ...
     {   # example N (tf.train.Example)
        'feature_1': tf.train.Feature,
        ...
        'feature_k': tf.train.Feature
    }
]


爲了將形式各樣的數據集整理爲 TFRecord 格式可按以下步驟:
1、讀取該數據元素到內存;
2、將該元素轉換爲 tf.train.Example 對象(每一個 tf.train.Example 由若干個 tf.train.Feature 的字典組成,因此需要先建立 Feature 的字典);
3、將該 tf.train.Example 對象序列化爲字符串,並通過一個預先定義的 tf.io.TFRecordWriter 寫入 TFRecord 文件。

而讀取 TFRecord 數據則可按照以下步驟:
1、通過 tf.data.TFRecordDataset 讀入原始的 TFRecord 文件(此時文件中的 tf.train.Example 對象尚未被反序列化),獲得一個 tf.data.Dataset 數據集對象;
2、通過 Dataset.map 方法,對該數據集對象中的每一個序列化的 tf.train.Example 字符串執行 tf.io.parse_single_example 函數,從而實現反序列化。

二、實戰例子理解,貓狗數據集轉換爲tfrecord文件,並讀取
貓狗數據集下載地址


# 將數據集存儲爲 TFRecord 文件

import tensorflow as tf
import os

data_dir = 'D:/深度學習/tensorflow2.0/catsAndDogs/datasets/'
train_cats_dir = data_dir + '/train/cats/'
train_dogs_dir = data_dir + '/train/dogs/'


tfrecord_file = data_dir + '/train/train.tfrecords'

train_cat_filenames = [train_cats_dir + filename for filename in os.listdir(train_cats_dir)]
train_dog_filenames = [train_dogs_dir + filename for filename in os.listdir(train_dogs_dir)]


train_filenames = train_cat_filenames + train_dog_filenames
train_labels = [0] * len(train_cat_filenames) + [1] * len(train_dog_filenames)  # 將 cat 類的標籤設爲0,dog 類的標籤設爲1


# 迭代讀取每張圖片,建立 tf.train.Feature 字典和 tf.train.Example 對象,序列化並寫入 TFRecord 文件。
with tf.io.TFRecordWriter(tfrecord_file) as writer:
    for filename, label in zip(train_filenames, train_labels):
        image = open(filename, 'rb').read()     # 讀取數據集圖片到內存,image 爲一個 Byte 類型的字符串
        feature = {                             # 建立 tf.train.Feature 字典
            'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),  # 圖片是一個 Bytes 對象
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))   # 標籤是一個 Int 對象
        }
        example = tf.train.Example(features=tf.train.Features(feature=feature)) # 通過字典建立 Example
        writer.write(example.SerializeToString())   # 將Example序列化並寫入 TFRecord 文件

# 運行以上代碼,不出片刻,我們即可在 tfrecord_file 所指向的文件地址獲得一個 500MB 左右的 train.tfrecords 文件。

# 讀取 TFRecord 文件
raw_dataset = tf.data.TFRecordDataset(tfrecord_file)    # 讀取 TFRecord 文件
feature_description = { # 定義Feature結構,告訴解碼器每個Feature的類型是什麼
    'image': tf.io.FixedLenFeature([], tf.string),
    'label': tf.io.FixedLenFeature([], tf.int64),
}

def _parse_example(example_string): # 將 TFRecord 文件中的每一個序列化的 tf.train.Example 解碼
    feature_dict = tf.io.parse_single_example(example_string, feature_description)
    feature_dict['image'] = tf.io.decode_jpeg(feature_dict['image'])    # 解碼JPEG圖片
    return feature_dict['image'], feature_dict['label']


dataset = raw_dataset.map(_parse_example)
# 運行以上代碼後,我們獲得一個數據集對象 dataset ,這已經是一個可以用於訓練的 tf.data.Dataset 對象了!

import matplotlib.pyplot as plt
for image, label in dataset:
   plt.title('cat' if label == 0 else 'dog')
   plt.imshow(image.numpy())
   plt.show()
發佈了653 篇原創文章 · 獲贊 795 · 訪問量 188萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章