[tensorflow]tf.data數據導入

目錄

一、tf.data簡介

二、讀取數據

1、從內存中讀取數據-numpy數組

2、從文件中讀取數據

三、變換Dataset中的元素

1、使用Dataset.map()預處理數據

2、使用Dataset.batch()批處理數據集元素

3、使用Dataset.shuffle()隨機重排輸入數據

4、使用Dataset.repeat()迭代數據集多個週期

四、創建Iterator訪問Dataset中的元素

1、單次迭代器

2、可初始化迭代器


一、tf.data簡介

藉助tf.data,構建輸入管道(將數據加載到模型)。

tf.data在TensorFlow中引入兩個新的抽象類:tf.data.Dataset、tf.data.Iterator.

Dataset:創建和轉化datasets的基類。初始化dataset兩種方式:從內存讀取數據,從Python生成器讀取數據。

TextLineDataset:從text文件中讀取數據,創建dataset。

FTRecordDataset:從TFRecord文件中讀取數據,創建dataset。

FixedLengthRecordDataset:從二進制文件中讀取固定大小的記錄,創建dataset。

Iterator:獲取dataset中的元素。

二、讀取數據

1、從內存中讀取數據-numpy數組

適合小型數據集,將所有數據加載到numpy數組中,使用tf.data.Dataset.from_tensor_slices()創建Dataset。

# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
  features = data["features"]
  labels = data["labels"]

# Assume that each row of `features` corresponds to the same row as `labels`.
assert features.shape[0] == labels.shape[0]

dataset = tf.data.Dataset.from_tensor_slices((features, labels))

2、從文件中讀取數據

tf.data支持多種文件格式,可以處理那些不適合存儲在內存中的大型數據集。

通過tf.data.TFRecordDataset類,讀取tfrecord文件

# Creates a dataset that reads all of the examples from two files.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)

通過tf.data.TextLineDataset類,讀取文本文件

filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.data.TextLineDataset(filenames)

通過tf.contrib.data.CsvDataset類,讀取csv文件

# Creates a dataset that reads all of the records from two CSV files, each with
# eight float columns
filenames = ["/var/data/file1.csv", "/var/data/file2.csv"]
record_defaults = [tf.float32] * 8   # Eight required float columns
dataset = tf.contrib.data.CsvDataset(filenames, record_defaults)

三、變換Dataset中的元素

變換Dataset中的元素方式,通常有:轉換map、批處理batch、次序混亂shuffle、處理多個週期repeat。

1、使用Dataset.map()預處理數據

Dataset.map(f)轉換將指定函數f應用於輸入數據集的每個元素來生成新數據集。

解析tf.train.Example協議緩衝區消息。許多輸入管道都從TFRecord格式的文件中提取tf.train.Example協議緩衝區消息,每個tf.train.Example記錄都包含一個或多個特徵,輸入管道通常將這些特徵轉換爲張量。

# Transforms a scalar string `example_proto` into a pair of a scalar string and
# a scalar integer, representing an image and its label, respectively.
def _parse_function(example_proto):
  features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
              "label": tf.FixedLenFeature((), tf.int64, default_value=0)}
  parsed_features = tf.parse_single_example(example_proto, features)
  return parsed_features["image"], parsed_features["label"]

# Creates a dataset that reads all of the examples from two files, and extracts
# the image and label features.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_parse_function)

解碼圖片數據並調整其大小。在用真實的圖片數據訓練神經網絡時,通常將不同大小的圖片轉換爲通用大小,這樣就可以將他們批處理爲具有固定大小的數據。

# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def _parse_function(filename, label):
  image_string = tf.read_file(filename)
  image_decoded = tf.image.decode_jpeg(image_string)
  image_resized = tf.image.resize_images(image_decoded, [28, 28])
  return image_resized, label

# A vector of filenames.
filenames = tf.constant(["/var/data/image1.jpg", "/var/data/image2.jpg", ...])

# `labels[i]` is the label for the image in `filenames[i].
labels = tf.constant([0, 37, ...])

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_parse_function)

2、使用Dataset.batch()批處理數據集元素

Dataset.batch()將數據集中的n個連續元素堆疊爲一個元素。使用限制,對於每個組件i,所有元素的張量形狀都必須完全相同。

inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
batched_dataset = dataset.batch(4)

iterator = batched_dataset.make_one_shot_iterator()
next_element = iterator.get_next()

print(sess.run(next_element))  # ==> ([0, 1, 2,   3],   [ 0, -1,  -2,  -3])
print(sess.run(next_element))  # ==> ([4, 5, 6,   7],   [-4, -5,  -6,  -7])
print(sess.run(next_element))  # ==> ([8, 9, 10, 11],   [-8, -9, -10, -11])

3、使用Dataset.shuffle()隨機重排輸入數據

Dataset.shuffle()會維持一個固定大小的緩衝區,並從該緩衝區中隨機地選擇下一個元素。

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat()

4、使用Dataset.repeat()迭代數據集多個週期

Dataset.repeat()創建一個將輸入重複多個週期的數據集

四、創建Iterator訪問Dataset中的元素

讀取Dataset中值的方法是構建迭代器對象。通過此對象可以一次訪問數據集中的一個對象。

1、單次迭代器

單次迭代器Dataset.make_one_shot_iterator(),僅支持對數據集進行一次迭代,不需要顯示初始化。目前,單次迭代器是唯一易於Estimator搭配使用的類型。

dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

for i in range(100):
  value = sess.run(next_element)
  assert i == value

2、可初始化迭代器

可初始化迭代器Dataset.make_initallizable_iterator(),允許使用一個或多個tf.placeholder()張量參數化數據集的定義,顯示iterator.initializer初始化後,纔可以讀取元素。

max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# Initialize an iterator over a dataset with 10 elements.
sess.run(iterator.initializer, feed_dict={max_value: 10})
for i in range(10):
  value = sess.run(next_element)
  assert i == value

# Initialize the same iterator over a dataset with 100 elements.
sess.run(iterator.initializer, feed_dict={max_value: 100})
for i in range(100):
  value = sess.run(next_element)
  assert i == value

 

 

 

參考資料:

https://developers.googleblog.com/2017/09/introducing-tensorflow-datasets.html

https://www.tensorflow.org/guide/datasets?hl=zh-cn#basic_mechanics

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章