TensorFlow Estimator 官方文檔之----Dataset for Estimator


tf.data 模塊包含一系列類,讓您可以輕鬆地加載數據、操作數據並通過管道將數據傳送到模型中。本文檔通過兩個簡單的示例來介紹該 API:

  • 從 Numpy 數組中讀取內存中的數據。
  • 從 csv 文件中讀取行。

從 Numpy 數組中讀取內存中的數據

要開始使用 tf.data,最簡單的方法是從內存中的數組中提取切片。

內置 Estimator 一章介紹了 iris_data.py 中的以下 train_input_fn,它可以通過管道將數據傳輸到 Estimator 中:

def train_input_fn(features, labels, batch_size):
    """An input function for training"""
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

    # Shuffle, repeat, and batch the examples.
    dataset = dataset.shuffle(1000).repeat().batch(batch_size)

    # Return the dataset.
    return dataset

我們來詳細瞭解一下。

參數

此函數需要三個參數。要求所賦值爲“數組”的參數能夠接受可通過 numpy.array 轉換成數組的幾乎任何值。其中存在一個例外,即對 Datasets 有特殊意義的 tuple,稍後我們會發現這一點。

  • features:包含原始輸入特徵的 {‘feature_name’:array} 字典(或 DataFrame)。
  • labels:包含每個樣本的標籤的數組。
  • batch_size:表示所需批次大小的整數。

premade_estimator.py 中,我們使用 iris_data.load_data() 函數檢索了鳶尾花數據。您可以運行該函數並解壓結果,如下所示:

import iris_data

# Fetch the data
train, test = iris_data.load_data()
features, labels = train

然後,我們使用類似以下內容的行將此數據傳遞給了輸入函數:

batch_size=100
iris_data.train_input_fn(features, labels, batch_size)

下面我們詳細介紹一下 train_input_fn()

切片

首先,此函數會利用 tf.data.Dataset.from_tensor_slices 函數創建一個代表數組切片的 tf.data.Dataset。系統會在第一個維度內對該數組進行切片。例如,mnist 訓練數據的數組的形狀爲 (60000, 28, 28)。將該數組傳遞給 from_tensor_slices 會返回一個包含 60000 個切片的 Dataset 對象,其中每個切片都是一個 28x28 的圖像。

返回此 Dataset 的代碼如下所示:

train, test = tf.keras.datasets.mnist.load_data()
mnist_x, mnist_y = train

mnist_ds = tf.data.Dataset.from_tensor_slices(mnist_x)
print(mnist_ds)

這段代碼將輸出以下行,顯示數據集中條目的 shapes 和 dtypes。請注意Dataset 不知道自己包含多少條目。

<TensorSliceDataset shapes: (28,28), types: tf.uint8>

上面的 Dataset 表示一組簡單的數組,但實際的數據集要比這複雜得多。Dataset 可以按照透明方式處理字典或元組(或 namedtuple)的任何嵌套組合。

例如,在將鳶尾花 features 轉換爲標準 Python 字典後,您可以將數組字典轉換爲字典 Dataset,如下所示:

dataset = tf.data.Dataset.from_tensor_slices(dict(features))
print(dataset)
<TensorSliceDataset

  shapes: {
    SepalLength: (), PetalWidth: (),
    PetalLength: (), SepalWidth: ()},

  types: {
      SepalLength: tf.float64, PetalWidth: tf.float64,
      PetalLength: tf.float64, SepalWidth: tf.float64}
>

我們可以看到,如果 Dataset 包含結構化元素,則 Datasetshapestypes 將採用同一結構。此數據集包含所有類型爲 tf.float64標量字典。

鳶尾花 train_input_fn 的第一行使用相同的功能,但添加了另一層結構。它會創建一個包含 (features_dict, label) 對的數據集。

以下代碼顯示標籤是類型爲 int64 的標量:

# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
print(dataset)
<TensorSliceDataset
    shapes: (
        {
          SepalLength: (), PetalWidth: (), 
          PetalLength: (), SepalWidth: ()},
        ()),

    types: (
        {
          SepalLength: tf.float64, PetalWidth: tf.float64, 
          PetalLength: tf.float64, SepalWidth: tf.float64},
        tf.int64)>

操作

目前,Dataset 會按固定順序迭代數據一次,並且一次僅生成一個元素。它需要進一步處理纔可用於訓練。幸運的是,tf.data.Dataset 類提供了更好地準備訓練數據的方法。輸入函數的下一行就利用了其中的幾種方法:

# Shuffle, repeat, and batch the examples.
dataset = dataset.shuffle(1000).repeat().batch(batch_size)

shuffle 方法使用一個固定大小的緩衝區,在條目經過時隨機化處理條目。在這種情況下,buffer_size 大於 Dataset 中樣本的數量,確保數據完全被隨機化處理(鳶尾花數據集僅包含 150 個樣本)。

repeat 方法會在結束時重啓 Dataset。要限制週期數量,請設置 count 參數。

batch 方法會收集大量樣本並將它們堆疊起來以創建批次。這爲批次的形狀增加了一個維度。新的維度將添加爲第一個維度。以下代碼對之前的 MNIST Dataset 使用 batch 方法。這樣會產生一個包含表示 (28,28) 圖像堆疊的三維數組的 Dataset

print(mnist_ds.batch(100))
<BatchDataset
  shapes: (?, 28, 28),
  types: tf.uint8>

請注意,該數據集的批次大小是未知的,因爲最後一個批次具有的元素數量會減少。

train_input_fn 中,經過批處理之後,Dataset 包含元素的一維向量,其中每個標量之前如下所示:

print(dataset)
<TensorSliceDataset
    shapes: (
        {
          SepalLength: (?,), PetalWidth: (?,),
          PetalLength: (?,), SepalWidth: (?,)},
        (?,)),

    types: (
        {
          SepalLength: tf.float64, PetalWidth: tf.float64,
          PetalLength: tf.float64, SepalWidth: tf.float64},
        tf.int64)>

返回

此時,Dataset 包含 (features_dict, labels) 對。這是 trainevaluate 方法的預期格式,因此 input_fn 會返回相應的數據集。

使用 predict 方法時,可以/應該忽略 labels

讀取 CSV 文件

Dataset 類最常見的實際用例是流式傳輸磁盤上文件中的數據。tf.data 模塊包含各種文件閱讀器。我們來看看如何使用 Dataset 解析 csv 文件中的 Iris 數據集。

iris_data.maybe_download 函數的以下調用會根據需要下載數據,並返回所生成文件的路徑名:

import iris_data
train_path, test_path = iris_data.maybe_download()

iris_data.csv_input_fn 函數包含使用 Dataset 解析 csv 文件的備用實現。

我們來了解一下如何構建從本地文件讀取數據且兼容 Estimator 的輸入函數。

構建 Dataset

我們先構建一個 TextLineDataset 對象,實現一次讀取文件中的一行數據。然後,我們調用 skip 方法來跳過文件的第一行,此行包含標題,而非樣本:

ds = tf.data.TextLineDataset(train_path).skip(1)

構建 csv 行解析器

我們先構建一個解析單行的函數。

以下 iris_data.parse_line 函數會使用 tf.decode_csv 函數和一些簡單的 Python 代碼來完成此任務:

爲了生成必要的 (features, label) 對,我們必須解析數據集中的每一行。以下 _parse_line 函數會調用 tf.decode_csv,以將單行解析爲特徵和標籤兩個部分。由於 Estimator 需要將特徵表示爲字典,因此我們依靠 Python 的內置 dictzip 函數來構建此字典。特徵名稱是該字典的鍵。然後,我們調用字典的 pop 方法以從特徵字典中移除標籤字段:

# Metadata describing the text columns
COLUMNS = ['SepalLength', 'SepalWidth',
           'PetalLength', 'PetalWidth',
           'label']
FIELD_DEFAULTS = [[0.0], [0.0], [0.0], [0.0], [0]]
def _parse_line(line):
    # Decode the line into its fields
    fields = tf.decode_csv(line, FIELD_DEFAULTS)

    # Pack the result into a dictionary
    features = dict(zip(COLUMNS,fields))

    # Separate the label from the features
    label = features.pop('label')

    return features, label

解析行

數據集提供很多用於在通過管道將數據傳送到模型的過程中處理數據的方法。最常用的方法是 map,它會對 Dataset 的每個元素應用轉換。

map 方法會接受 map_func 參數,此參數描述了應該如何轉換 Dataset 中的每個條目。

在這裏插入圖片描述

因此,爲了在從 csv 文件中流式傳出行時對行進行解析,我們將 _parse_line 函數傳遞給 map 方法:

ds = ds.map(_parse_line)
print(ds)
<MapDataset
shapes: (
    {SepalLength: (), PetalWidth: (), ...},
    ()),
types: (
    {SepalLength: tf.float32, PetalWidth: tf.float32, ...},
    tf.int32)>

現在,數據集包含 (features, label) 對,而不是簡單的標量字符串。

iris_data.csv_input_fn 函數的剩餘部分與 iris_data.train_input_fn 函數完全相同,後者在基本輸入部分中進行了介紹。

試試看

此函數可用於替換 iris_data.train_input_fn。可使用此函數饋送 Estimator,如下所示:

train_path, test_path = iris_data.maybe_download()

# All the inputs are numeric
feature_columns = [
    tf.feature_column.numeric_column(name)
    for name in iris_data.CSV_COLUMN_NAMES[: -1]
]

# Build the estimator
est = tf.estimator.LinearClassifier(feature_columns,
    n_classes = 3)# Train the estimator
batch_size = 100
est.train(
    steps = 1000,
    input_fn = lambda: iris_data.csv_input_fn(train_path, batch_size))

Estimator 要求 input_fn 不接受任何參數。爲了不受此限制約束,我們使用 lambda 來獲取參數並提供所需的接口。

總結

tf.data 模塊提供一系列類和函數,可用於輕鬆從各種來源讀取數據。此外,tf.data 還提供簡單而又強大的方法,用於應用各種標準和自定義轉換。

現在,您已經基本瞭解瞭如何高效地將數據加載到 Estimator 中。接下來,請查看下列文檔:

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章