tensorflow中的dataset API

來源

1.數據導入

tf.data API可以讓你以簡單可複用的方式構建複雜的Input Pipeline。例如:一個圖片模型的Pipeline可能會聚合在一個分佈式文件系統中的多個文件,對每個圖片進行隨機擾動(random perturbations),接着將隨機選中的圖片合併到一個training batch中。一個文本模型的Pipeline可能涉及到:從原始文本數據中抽取特徵,將它們通過一個lookup table轉換成embedding identifiers,然後將不同的長度序列batch在一起。tf.data API可以很方便地以不同的數據格式處理大量的數據,以及處理複雜的轉換。

Dataset API引入了兩個新的抽象類到Tensorflow中:

  • tf.data.Dataset:表示一串元素(elements),其中每個元素包含了一或多個Tensor對象。例如:在一個圖片pipeline中,一個元素可以是單個訓練樣本,它們帶有一個表示圖片數據的tensors和一個label組成的pair。有兩種不同的方式創建一個dataset
    • 創建一個source (例如:Dataset.from_tensor_slices()), 從一或多個tf.Tensor對象中構建一個dataset
    • 應用一個transformation(例如:Dataset.batch()),從一或多個tf.data.Dataset對象上構建一個dataset
  • tf.data.Iterator:它提供了主要的方式來從一個dataset中抽取元素。通過Iterator.get_next() 返回的該操作會yields出Datasets中的下一個元素,作爲輸入pipeline和模型間的接口使用。最簡單的iterator是一個“one-shot iterator”,它與一個指定的Dataset相關聯,通過它來進行迭代。對於更復雜的使用,Iterator.initializer操作可以使用不同的datasets重新初始化(reinitialize)和參數化(parameterize)一個iterator ,例如,在同一個程序中通過training data和validation data迭代多次。

2.基本機制

這部分描述了創建不同Dataset和Iterator對象的機制,以及如何使用它們來抽取數據。

要想啓動一個input pipeline,你必須定義一個source。例如,爲了從內存中的一些tensors構建一個Dataset,你可以使用tf.data.Dataset.from_tensors() 以及tf.data.Dataset.from_tensor_slices()。另一種方法,如果你的輸入數據在磁盤上以推薦的TFRecord格式存儲,你可以構建一個tf.data.TFRecordDataset。一旦你有一個Dataset對象,通過在tf.data.Dataset對象上鍊式方法調用,你可以將它轉化成一個新的Dataset。例如,你可以使用per-element transformations,比如:Dataset.map(),(它會在每個元素上應用一個function),以及multi-element transformations,比如:Dataset.batch()。更多詳見api

從一個Dataset上消費values的最常用方法,是生成一個iterator對象,它提供了一次可以訪問dataset中的一個元素(例如:通過調用Dataset.make_one_shot_iterator())。tf.data.Iterator提供了兩個操作

  • Iterator.initializer:它允許你(re)initialize iterator的狀態
  • Iterator.get_next():它返回tf.Tensor對象,對應於指定的下一個元素。

2.1 Dataset結構

一個dataset由element組成,它們每個都具有相同的結構。一個元素包含了一或多個tf.Tensor對象,稱爲“components“。每個component都具有一個tf.DType:它表示在tensor中的元素的類型;以及一個tf.TensorShape:它表示每個元素的靜態shape。Dataset.output_types 和 Dataset.output_shapes 屬性允許你觀察到一個dataset元素的每個component內省的types和shapes。這些屬性的這種嵌套式結構(nested structure),映射到一個元素(它可以是單個tensor、一個tensors的tuple、一個tensors的嵌套式tuple)的結構上。例如:

dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
print(dataset1.output_types)  # ==> "tf.float32"
print(dataset1.output_shapes)  # ==> "(10,)"

dataset2 = tf.data.Dataset.from_tensor_slices(
   (tf.random_uniform([4]),
    tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)))
print(dataset2.output_types)  # ==> "(tf.float32, tf.int32)"
print(dataset2.output_shapes)  # ==> "((), (100,))"

dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
print(dataset3.output_types)  # ==> (tf.float32, (tf.float32, tf.int32))
print(dataset3.output_shapes)  # ==> "(10, ((), (100,)))"

爲一個元素(element)的每個component給定names很方便,例如,如果它們表示一個訓練樣本的不同features。除了tuples,你可以使用collections.namedtuple,或者一個將strings映射爲關於tensors的字典來表示一個Dataset的單個元素。

dataset = tf.data.Dataset.from_tensor_slices(
   {"a": tf.random_uniform([4]),
    "b": tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)})
print(dataset.output_types)  # ==> "{'a': tf.float32, 'b': tf.int32}"
print(dataset.output_shapes)  # ==> "{'a': (), 'b': (100,)}"

Dataset的轉換(transformations)支持任何結構的datasets。當使用Dataset.map(),Dataset.flat_map(),以及Dataset.filter()轉換時,它們會對每個element應用一個function,元素結構決定了函數的參數:

dataset1 = dataset1.map(lambda x: ...)

dataset2 = dataset2.flat_map(lambda x, y: ...)

# Note: Argument destructuring is not available in Python 3.
dataset3 = dataset3.filter(lambda x, (y, z): ...)

2.2 創建一個iterator

一旦你已經構建了一個Dataset來表示你的輸入數據,下一步是創建一個Iterator來訪問dataset的elements。Dataset API當前支持四種iterator,複雜度依次遞增:

  • one-shot
  • initializable
  • reinitializable
  • feedable

one-shot iterator是最簡單的iterator,它只支持在一個dataset上迭代一次的操作,不需要顯式初始化。One-shot iterators可以處理幾乎所有的己存在的基於隊列的input pipeline支持的情況,但它們不支持參數化(parameterization)。使用Dataset.range()示例如下:

dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

for i in range(100):
  value = sess.run(next_element)
  assert i == value

initializable iterator在使用它之前需要你返回一個顯式的iterator.initializer操作。雖然有些不便,但它允許你可以對dataset的定義進行參數化(parameterize),使用一或多個tf.placeholder() tensors:它們可以當你初始化iterator時被feed進去。繼續Dataset.range() 的示例:

max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# Initialize an iterator over a dataset with 10 elements.
sess.run(iterator.initializer, feed_dict={max_value: 10})
for i in range(10):
  value = sess.run(next_element)
  assert i == value

# Initialize the same iterator over a dataset with 100 elements.
sess.run(iterator.initializer, feed_dict={max_value: 100})
for i in range(100):
  value = sess.run(next_element)
  assert i == value

reinitializable iterator可以從多個不同的Dataset對象處初始化。例如,你可能有一個training input pipeline(它對輸入圖片做隨機擾動來提高泛化能力);以及一個validation input pipeline(它會在未修改過的數據上進行預測的評估)。這些pipeline通常使用不同的Dataset對象,但它們具有相同的結構(例如:對每個component相同的types和shapes)

# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
    lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
validation_dataset = tf.data.Dataset.range(50)

# A reinitializable iterator is defined by its structure. We could use the
# `output_types` and `output_shapes` properties of either `training_dataset`
# or `validation_dataset` here, because they are compatible.
iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
                                   training_dataset.output_shapes)
next_element = iterator.get_next()

training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)

# Run 20 epochs in which the training dataset is traversed, followed by the
# validation dataset.
for _ in range(20):
  # Initialize an iterator over the training dataset.
  sess.run(training_init_op)
  for _ in range(100):
    sess.run(next_element)

  # Initialize an iterator over the validation dataset.
  sess.run(validation_init_op)
  for _ in range(50):
    sess.run(next_element)

feedable iterator可以與tf.placeholder一起使用,通過熟悉的feed_dict機制,來選擇在每次調用tf.Session.run所使用的Iterator,。它提供了與reinitializable iterator相同的功能,但當你在iterators間相互切換時,它不需要你去初始化iterator。例如:使用上述相同的training和validation樣本,你可以使用tf.data.Iterator.from_string_handle來定義一個feedable iterator,並允許你在兩個datasets間切換:

# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
    lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
validation_dataset = tf.data.Dataset.range(50)

# A feedable iterator is defined by a handle placeholder and its structure. We
# could use the `output_types` and `output_shapes` properties of either
# `training_dataset` or `validation_dataset` here, because they have
# identical structure.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()

# You can use feedable iterators with a variety of different kinds of iterator
# (such as one-shot and initializable iterators).
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()

# The `Iterator.string_handle()` method returns a tensor that can be evaluated
# and used to feed the `handle` placeholder.
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())

# Loop forever, alternating between training and validation.
while True:
  # Run 200 steps using the training dataset. Note that the training dataset is
  # infinite, and we resume from where we left off in the previous `while` loop
  # iteration.
  for _ in range(200):
    sess.run(next_element, feed_dict={handle: training_handle})

  # Run one pass over the validation dataset.
  sess.run(validation_iterator.initializer)
  for _ in range(50):
    sess.run(next_element, feed_dict={handle: validation_handle})

2.3 從一個iterator上消費values

Iterator.get_next()方法會返回一或多個tf.Tensor對象,對應於一個iterator的下一個element。每次這些tensors被評測時,它們會在底層的dataset中獲得下一個element的value。(注意:類似於Tensorflow中其它的有狀態對象,調用Iterator.get_next() 不會立即讓iterator前移。相反的,你必須使用Tensorflow表達式所返回的tf.Tensor對象,傳遞該表達式的結果給tf.Session.run(),來獲取下一個elements,並讓iterator前移)

如果iterator達到了dataset的結尾,執行Iterator.get_next() 操作會拋出一個tf.errors.OutOfRangeError。在這之後,iterator會以一個不可用的狀態存在,如果你想進一步使用必須重新初始化它。

dataset = tf.data.Dataset.range(5)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# Typically `result` will be the output of a model, or an optimizer's
# training operation.
result = tf.add(next_element, next_element)

sess.run(iterator.initializer)
print(sess.run(result))  # ==> "0"
print(sess.run(result))  # ==> "2"
print(sess.run(result))  # ==> "4"
print(sess.run(result))  # ==> "6"
print(sess.run(result))  # ==> "8"
try:
  sess.run(result)
except tf.errors.OutOfRangeError:
  print("End of dataset")  # ==> "End of dataset"

一種常用的模式是,將”training loop”封裝到一個try-except塊中:

sess.run(iterator.initializer)
while True:
  try:
    sess.run(result)
  except tf.errors.OutOfRangeError:
    break

如果dataset的每個元素都具有一個嵌套的結構,Iterator.get_next()的返回值將會是以相同嵌套結構存在的一或多個tf.Tensor對象

dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
dataset2 = tf.data.Dataset.from_tensor_slices((tf.random_uniform([4]), tf.random_uniform([4, 100])))
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

iterator = dataset3.make_initializable_iterator()

sess.run(iterator.initializer)
next1, (next2, next3) = iterator.get_next()

注意,對next1, next2, or next3的任意一個進行評估都會爲所有components進行iterator。一個iterator的一種常見consumer將包含在單個表達式中的所有components。

3.讀取輸入數據

3.1 消費Numpy arrays

如果所有的輸入數據都加載進內存,最簡單的方式是,從輸入數據中創建一個Dataset,並將它們轉換成tf.Tensor對象,並使用Dataset.from_tensor_slices()。

# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
  features = data["features"]
  labels = data["labels"]

# Assume that each row of `features` corresponds to the same row as `labels`.
assert features.shape[0] == labels.shape[0]

dataset = tf.data.Dataset.from_tensor_slices((features, labels))

注意,上述的代碼段會將features arrays和labels arrays作爲tf.constant() 操作嵌套進你的TensorFlow graph中。這在小數據集上能良好運行,但會浪費內存——因爲array的內存會被拷貝多次————對於tf.GraphDef的protocol buffer,只可以運行2GB的內存限制。

# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
  features = data["features"]
  labels = data["labels"]

# Assume that each row of `features` corresponds to the same row as `labels`.
assert features.shape[0] == labels.shape[0]

features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
# [Other transformations on `dataset`...]
dataset = ...
iterator = dataset.make_initializable_iterator()

sess.run(iterator.initializer, feed_dict={features_placeholder: features,
                                          labels_placeholder: labels})

3.2 消費TFRecord數據

Dataset API支持多種文件格式,因此你可以處理超過內存大小的大數據集。例如,TFRecord文件格式是一種簡單的面向記錄的二進制格式,許多TensorFlow應用都用它來做訓練數據。tf.data.TFRecordDataset類允許你在一或多個TFRecord文件的內容上進行流化,將它們作爲input pipeline的一部分:

# Creates a dataset that reads all of the examples from two files.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)

TFRecordDataset initializer的filenames參數,可以是一個string,也可以是一列string,或者關於strings的一個tf.Tensor。因此,如果你具有兩個文件集合,分別對應訓練數據和驗證數據,你可以使用一個tf.placeholder(tf.string)來表示filenames,並從合適的filenames上初始化一個iterator:

filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)  # Parse the record into tensors.
dataset = dataset.repeat()  # Repeat the input indefinitely.
dataset = dataset.batch(32)
iterator = dataset.make_initializable_iterator()

# You can feed the initializer with the appropriate filenames for the current
# phase of execution, e.g. training vs. validation.

# Initialize `iterator` with training data.
training_filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
sess.run(iterator.initializer, feed_dict={filenames: training_filenames})

# Initialize `iterator` with validation data.
validation_filenames = ["/var/data/validation1.tfrecord", ...]
sess.run(iterator.initializer, feed_dict={filenames: validation_filenames})

3.3 消費文本數據

許多datasets以一或多個文本文件分佈。tf.data.TextLineDataset提供了一種簡單的方式來從文本文件中抽取行(lines)。給定一或多個filenames,一個TextLineDataset將爲這些文件的每行生成一個string型的element。與TFRecordDataset類似,TextLineDataset會接受filenames參數作爲一個tf.Tensor,因此你可以通過傳遞一個tf.placeholder(tf.string)對它參數化。

filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.data.TextLineDataset(filenames)

缺省的,一個TextLineDataset會yields每個文件的所有行,這不是我們所希望的,例如,如果該文件使用一個header line開始,或包含註釋。這些行通過Dataset.skip() 和 Dataset.filter() 轉換被移去。爲了將這些轉換獨立地應用每個文件上,我們使用Dataset.flat_map() 來爲每個文件創建一個嵌套的Dataset。

filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]

dataset = tf.data.Dataset.from_tensor_slices(filenames)

# Use `Dataset.flat_map()` to transform each file as a separate nested dataset,
# and then concatenate their contents sequentially into a single "flat" dataset.
# * Skip the first line (header row).
# * Filter out lines beginning with "#" (comments).
dataset = dataset.flat_map(
    lambda filename: (
        tf.data.TextLineDataset(filename)
        .skip(1)
        .filter(lambda line: tf.not_equal(tf.substr(line, 0, 1), "#"))))

4.使用Dataset.map()預處理數據

通過在輸入數據集的每個element上應用一個給定的函數f,Dataset.map(f)變換會產生一個新的dataset。該函數f會接受tf.Tensor對象(它表示input中的單個element)作爲參數,並返回tf.Tensor對象(它表示在new dataset中的單個element)。它的實現使用了標準的TensorFlow操作來將一個element轉換成另一個。

本節包含了如何使用Dataset.map()的示例。

4.1 解析tf.Example protocol buffer messages

許多input pipelines會從一個TFRecord格式的文件中抽取tf.train.Example protocol buffer messages(例如:使用tf.python_io.TFRecordWriter)。每個tf.train.Example record包含一或多個”features”,input pipeline通常會將這些features轉換成tensors。

# Transforms a scalar string `example_proto` into a pair of a scalar string and
# a scalar integer, representing an image and its label, respectively.
def _parse_function(example_proto):
  features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
              "label": tf.FixedLenFeature((), tf.int32, default_value=0)}
  parsed_features = tf.parse_single_example(example_proto, features)
  return parsed_features["image"], parsed_features["label"]

# Creates a dataset that reads all of the examples from two files, and extracts
# the image and label features.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_parse_function)

4.2 將圖片數據進行decoding,並resizing

當對真實世界的圖片數據訓練一個神經網絡時,經常需要將不同size的圖片轉換成同一size,因此,必須批量轉換成一個固定的size。

# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def _parse_function(filename, label):
  image_string = tf.read_file(filename)
  image_decoded = tf.image.decode_image(image_string)
  image_resized = tf.image.resize_images(image_decoded, [28, 28])
  return image_resized, label

# A vector of filenames.
filenames = tf.constant(["/var/data/image1.jpg", "/var/data/image2.jpg", ...])

# `labels[i]` is the label for the image in `filenames[i].
labels = tf.constant([0, 37, ...])

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_parse_function)

4.3 使用tf.py_func()

出於性能的原因,我們鼓勵你去使用TensorFlow operations來預處理數據。然而,有時,當解析你的輸入數據時調用額外的python庫會很有用。可以通過在一個Dataset.map() 轉換上調用tf.py_func() operation來達到這一點。

import cv2

# Use a custom OpenCV function to read the image, instead of the standard
# TensorFlow `tf.read_file()` operation.
def _read_py_function(filename, label):
  image_decoded = cv2.imread(image_string, cv2.IMREAD_GRAYSCALE)
  return image_decoded, label

# Use standard TensorFlow operations to resize the image to a fixed shape.
def _resize_function(image_decoded, label):
  image_decoded.set_shape([None, None, None])
  image_resized = tf.image.resize_images(image_decoded, [28, 28])
  return image_resized, label

filenames = ["/var/data/image1.jpg", "/var/data/image2.jpg", ...]
labels = [0, 37, 29, 1, ...]

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(
    lambda filename, label: tf.py_func(
        _read_py_function, [filename, label], [tf.uint8, label.dtype]))
dataset = dataset.map(_resize_function)

5.打包元素(Batching dataset elements)

5.1 簡單的batching

batching的最簡單方式是,將數據集上n個連續的elements進行stack成單個elements。Dataset.batch() 轉換可以精準地做到這一點,它使用與tf.stack() 操作相同的constraints,應用在元素的每個component上:例如,對於每個元素i,所有元素必須具有一個相同shape的tensor:

inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
batched_dataset = dataset.batch(4)

iterator = batched_dataset.make_one_shot_iterator()
next_element = iterator.get_next()

print(sess.run(next_element))  # ==> ([0, 1, 2,   3],   [ 0, -1,  -2,  -3])
print(sess.run(next_element))  # ==> ([4, 5, 6,   7],   [-4, -5,  -6,  -7])
print(sess.run(next_element))  # ==> ([8, 9, 10, 11],   [-8, -9, -10, -11])

5.2 使用padding打包tensors

上面的方法需要相同的size。然而,許多模型(比如:序列模型)的輸入數據的size多種多樣(例如:序列具有不同的長度)爲了處理這種情況,Dataset.padded_batch() 轉換允許你將不同shape的tensors進行batch,通過指定一或多個dimensions,在其上進行pad。

dataset = tf.data.Dataset.range(100)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
dataset = dataset.padded_batch(4, padded_shapes=[None])

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

print(sess.run(next_element))  # ==> [[0, 0, 0], [1, 0, 0], [2, 2, 0], [3, 3, 3]]
print(sess.run(next_element))  # ==> [[4, 4, 4, 4, 0, 0, 0],
                               #      [5, 5, 5, 5, 5, 0, 0],
                               #      [6, 6, 6, 6, 6, 6, 0],
                               #      [7, 7, 7, 7, 7, 7, 7]]

Dataset.padded_batch() 轉換允許你爲每個component的每個dimension設置不同的padding,它可以是可變的長度(在樣本上指定None即可)或恆定長度。你可以對padding值(缺省爲0.0)進行override。

6.訓練工作流(Training workflows)

6.1 處理多個epochs

Dataset API提供了兩種主要方式來處理相同數據的多個epochs。

最簡單的方式是,在一個dataset上使用Dataset.repeat()轉換進行多輪迭代。例如:創建一個dataset,並repeat它的輸入10個epochs。

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.repeat(10)
dataset = dataset.batch(32)

使用無參數的Dataset.repeat() 會不斷重複input 。Dataset.repeat() 轉換將它的參數進行連接,無需一輪的結束處以及下一輪的開始處發出信號。

如果你想在每一輪的結尾接收到一個信號,你可以編寫一個training loop,在dataset的結尾處捕獲tf.errors.OutOfRangeError。在那時刻,你可以收集到該輪的一些統計信息(例如:validation error)

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.batch(32)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# Compute for 100 epochs.
for _ in range(100):
  sess.run(iterator.initializer)
  while True:
    try:
      sess.run(next_element)
    except tf.errors.OutOfRangeError:
      break

  # [Perform end-of-epoch calculations here.]

6.2 對輸入數據進行random shuffling

Dataset.shuffle() 轉換會與tf.RandomShuffleQueue使用相同的算法對輸入數據集進行隨機shuffle:它會維持一個固定大小的buffer,並從該buffer中隨機均勻地選擇下一個元素:

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat()

6.3 使用高級API

tf.train.MonitoredTrainingSession API可以簡化分佈式設置下運行的Tensorflow的許多方面。當訓練完成時,MonitoredTrainingSession使用 tf.errors.OutOfRangeError來發射信號,因此爲了配合Dataset API使用它,我們推薦使用Dataset.make_one_shot_iterator()。例如:

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat(num_epochs)
iterator = dataset.make_one_shot_iterator()

next_example, next_label = iterator.get_next()
loss = model_function(next_example, next_label)

training_op = tf.train.AdagradOptimizer(...).minimize(loss)

with tf.train.MonitoredTrainingSession(...) as sess:
  while not sess.should_stop():
    sess.run(training_op)

爲了在tf.estimator.Estimator的input_fn使用一個Dataset,我們推薦使用Dataset.make_one_shot_iterator()。例如:

def dataset_input_fn():
  filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
  dataset = tf.data.TFRecordDataset(filenames)

  # Use `tf.parse_single_example()` to extract data from a `tf.Example`
  # protocol buffer, and perform any additional per-record preprocessing.
  def parser(record):
    keys_to_features = {
        "image_data": tf.FixedLenFeature((), tf.string, default_value=""),
        "date_time": tf.FixedLenFeature((), tf.int64, default_value=""),
        "label": tf.FixedLenFeature((), tf.int64,
                                    default_value=tf.zeros([], dtype=tf.int64)),
    }
    parsed = tf.parse_single_example(record, keys_to_features)

    # Perform additional preprocessing on the parsed data.
    image = tf.decode_jpeg(parsed["image_data"])
    image = tf.reshape(image, [299, 299, 1])
    label = tf.cast(parsed["label"], tf.int32)

    return {"image_data": image, "date_time": parsed["date_time"]}, label

  # Use `Dataset.map()` to build a pair of a feature dictionary and a label
  # tensor for each example.
  dataset = dataset.map(parser)
  dataset = dataset.shuffle(buffer_size=10000)
  dataset = dataset.batch(32)
  dataset = dataset.repeat(num_epochs)
  iterator = dataset.make_one_shot_iterator()

  # `features` is a dictionary in which each value is a batch of values for
  # that feature; `labels` is a batch of labels.
  features, labels = iterator.get_next()
  return features, labels

參考

官方tensorflow datasets

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章