目錄
一、tf.data簡介
藉助tf.data,構建輸入管道(將數據加載到模型)。
tf.data在TensorFlow中引入兩個新的抽象類:tf.data.Dataset、tf.data.Iterator.
Dataset:創建和轉化datasets的基類。初始化dataset兩種方式:從內存讀取數據,從Python生成器讀取數據。
TextLineDataset:從text文件中讀取數據,創建dataset。
FTRecordDataset:從TFRecord文件中讀取數據,創建dataset。
FixedLengthRecordDataset:從二進制文件中讀取固定大小的記錄,創建dataset。
Iterator:獲取dataset中的元素。
二、讀取數據
1、從內存中讀取數據-numpy數組
適合小型數據集,將所有數據加載到numpy數組中,使用tf.data.Dataset.from_tensor_slices()創建Dataset。
# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
features = data["features"]
labels = data["labels"]
# Assume that each row of `features` corresponds to the same row as `labels`.
assert features.shape[0] == labels.shape[0]
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
2、從文件中讀取數據
tf.data支持多種文件格式,可以處理那些不適合存儲在內存中的大型數據集。
通過tf.data.TFRecordDataset類,讀取tfrecord文件:
# Creates a dataset that reads all of the examples from two files.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
通過tf.data.TextLineDataset類,讀取文本文件:
filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.data.TextLineDataset(filenames)
通過tf.contrib.data.CsvDataset類,讀取csv文件:
# Creates a dataset that reads all of the records from two CSV files, each with
# eight float columns
filenames = ["/var/data/file1.csv", "/var/data/file2.csv"]
record_defaults = [tf.float32] * 8 # Eight required float columns
dataset = tf.contrib.data.CsvDataset(filenames, record_defaults)
三、變換Dataset中的元素
變換Dataset中的元素方式,通常有:轉換map、批處理batch、次序混亂shuffle、處理多個週期repeat。
1、使用Dataset.map()預處理數據
Dataset.map(f)轉換將指定函數f應用於輸入數據集的每個元素來生成新數據集。
解析tf.train.Example協議緩衝區消息。許多輸入管道都從TFRecord格式的文件中提取tf.train.Example協議緩衝區消息,每個tf.train.Example記錄都包含一個或多個特徵,輸入管道通常將這些特徵轉換爲張量。
# Transforms a scalar string `example_proto` into a pair of a scalar string and
# a scalar integer, representing an image and its label, respectively.
def _parse_function(example_proto):
features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
"label": tf.FixedLenFeature((), tf.int64, default_value=0)}
parsed_features = tf.parse_single_example(example_proto, features)
return parsed_features["image"], parsed_features["label"]
# Creates a dataset that reads all of the examples from two files, and extracts
# the image and label features.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_parse_function)
解碼圖片數據並調整其大小。在用真實的圖片數據訓練神經網絡時,通常將不同大小的圖片轉換爲通用大小,這樣就可以將他們批處理爲具有固定大小的數據。
# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def _parse_function(filename, label):
image_string = tf.read_file(filename)
image_decoded = tf.image.decode_jpeg(image_string)
image_resized = tf.image.resize_images(image_decoded, [28, 28])
return image_resized, label
# A vector of filenames.
filenames = tf.constant(["/var/data/image1.jpg", "/var/data/image2.jpg", ...])
# `labels[i]` is the label for the image in `filenames[i].
labels = tf.constant([0, 37, ...])
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_parse_function)
2、使用Dataset.batch()批處理數據集元素
Dataset.batch()將數據集中的n個連續元素堆疊爲一個元素。使用限制,對於每個組件i,所有元素的張量形狀都必須完全相同。
inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
batched_dataset = dataset.batch(4)
iterator = batched_dataset.make_one_shot_iterator()
next_element = iterator.get_next()
print(sess.run(next_element)) # ==> ([0, 1, 2, 3], [ 0, -1, -2, -3])
print(sess.run(next_element)) # ==> ([4, 5, 6, 7], [-4, -5, -6, -7])
print(sess.run(next_element)) # ==> ([8, 9, 10, 11], [-8, -9, -10, -11])
3、使用Dataset.shuffle()隨機重排輸入數據
Dataset.shuffle()會維持一個固定大小的緩衝區,並從該緩衝區中隨機地選擇下一個元素。
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat()
4、使用Dataset.repeat()迭代數據集多個週期
Dataset.repeat()創建一個將輸入重複多個週期的數據集
四、創建Iterator訪問Dataset中的元素
讀取Dataset中值的方法是構建迭代器對象。通過此對象可以一次訪問數據集中的一個對象。
1、單次迭代器
單次迭代器Dataset.make_one_shot_iterator(),僅支持對數據集進行一次迭代,不需要顯示初始化。目前,單次迭代器是唯一易於Estimator搭配使用的類型。
dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
for i in range(100):
value = sess.run(next_element)
assert i == value
2、可初始化迭代器
可初始化迭代器Dataset.make_initallizable_iterator(),允許使用一個或多個tf.placeholder()張量參數化數據集的定義,顯示iterator.initializer初始化後,纔可以讀取元素。
max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# Initialize an iterator over a dataset with 10 elements.
sess.run(iterator.initializer, feed_dict={max_value: 10})
for i in range(10):
value = sess.run(next_element)
assert i == value
# Initialize the same iterator over a dataset with 100 elements.
sess.run(iterator.initializer, feed_dict={max_value: 100})
for i in range(100):
value = sess.run(next_element)
assert i == value
參考資料:
https://developers.googleblog.com/2017/09/introducing-tensorflow-datasets.html
https://www.tensorflow.org/guide/datasets?hl=zh-cn#basic_mechanics