tf.data
模塊包含一系列類,讓您可以輕鬆地加載數據、操作數據並通過管道將數據傳送到模型中。本文檔通過兩個簡單的示例來介紹該 API:
- 從 Numpy 數組中讀取內存中的數據。
- 從 csv 文件中讀取行。
從 Numpy 數組中讀取內存中的數據
要開始使用 tf.data
,最簡單的方法是從內存中的數組中提取切片。
內置 Estimator 一章介紹了 iris_data.py
中的以下 train_input_fn
,它可以通過管道將數據傳輸到 Estimator 中:
def train_input_fn(features, labels, batch_size):
"""An input function for training"""
# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
# Shuffle, repeat, and batch the examples.
dataset = dataset.shuffle(1000).repeat().batch(batch_size)
# Return the dataset.
return dataset
我們來詳細瞭解一下。
參數
此函數需要三個參數。要求所賦值爲“數組”的參數能夠接受可通過 numpy.array
轉換成數組的幾乎任何值。其中存在一個例外,即對 Datasets
有特殊意義的 tuple
,稍後我們會發現這一點。
features
:包含原始輸入特徵的{‘feature_name’:array}
字典(或DataFrame
)。labels
:包含每個樣本的標籤的數組。batch_size
:表示所需批次大小的整數。
在 premade_estimator.py
中,我們使用 iris_data.load_data()
函數檢索了鳶尾花數據。您可以運行該函數並解壓結果,如下所示:
import iris_data
# Fetch the data
train, test = iris_data.load_data()
features, labels = train
然後,我們使用類似以下內容的行將此數據傳遞給了輸入函數:
batch_size=100
iris_data.train_input_fn(features, labels, batch_size)
下面我們詳細介紹一下 train_input_fn()
。
切片
首先,此函數會利用 tf.data.Dataset.from_tensor_slices
函數創建一個代表數組切片的 tf.data.Dataset
。系統會在第一個維度內對該數組進行切片。例如,mnist 訓練數據的數組的形狀爲 (60000, 28, 28)
。將該數組傳遞給 from_tensor_slices
會返回一個包含 60000 個切片的 Dataset
對象,其中每個切片都是一個 28x28 的圖像。
返回此 Dataset
的代碼如下所示:
train, test = tf.keras.datasets.mnist.load_data()
mnist_x, mnist_y = train
mnist_ds = tf.data.Dataset.from_tensor_slices(mnist_x)
print(mnist_ds)
這段代碼將輸出以下行,顯示數據集中條目的 shapes 和 dtypes。請注意,Dataset
不知道自己包含多少條目。
<TensorSliceDataset shapes: (28,28), types: tf.uint8>
上面的 Dataset
表示一組簡單的數組,但實際的數據集要比這複雜得多。Dataset
可以按照透明方式處理字典或元組(或 namedtuple
)的任何嵌套組合。
例如,在將鳶尾花 features
轉換爲標準 Python 字典後,您可以將數組字典轉換爲字典 Dataset
,如下所示:
dataset = tf.data.Dataset.from_tensor_slices(dict(features))
print(dataset)
<TensorSliceDataset
shapes: {
SepalLength: (), PetalWidth: (),
PetalLength: (), SepalWidth: ()},
types: {
SepalLength: tf.float64, PetalWidth: tf.float64,
PetalLength: tf.float64, SepalWidth: tf.float64}
>
我們可以看到,如果 Dataset
包含結構化元素,則 Dataset
的 shapes
和 types
將採用同一結構。此數據集包含所有類型爲 tf.float64
的標量字典。
鳶尾花 train_input_fn
的第一行使用相同的功能,但添加了另一層結構。它會創建一個包含 (features_dict, label)
對的數據集。
以下代碼顯示標籤是類型爲 int64
的標量:
# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
print(dataset)
<TensorSliceDataset
shapes: (
{
SepalLength: (), PetalWidth: (),
PetalLength: (), SepalWidth: ()},
()),
types: (
{
SepalLength: tf.float64, PetalWidth: tf.float64,
PetalLength: tf.float64, SepalWidth: tf.float64},
tf.int64)>
操作
目前,Dataset
會按固定順序迭代數據一次,並且一次僅生成一個元素。它需要進一步處理纔可用於訓練。幸運的是,tf.data.Dataset
類提供了更好地準備訓練數據的方法。輸入函數的下一行就利用了其中的幾種方法:
# Shuffle, repeat, and batch the examples.
dataset = dataset.shuffle(1000).repeat().batch(batch_size)
shuffle
方法使用一個固定大小的緩衝區,在條目經過時隨機化處理條目。在這種情況下,buffer_size
大於 Dataset
中樣本的數量,確保數據完全被隨機化處理(鳶尾花數據集僅包含 150 個樣本)。
repeat
方法會在結束時重啓 Dataset
。要限制週期數量,請設置 count
參數。
batch
方法會收集大量樣本並將它們堆疊起來以創建批次。這爲批次的形狀增加了一個維度。新的維度將添加爲第一個維度。以下代碼對之前的 MNIST Dataset
使用 batch
方法。這樣會產生一個包含表示 (28,28)
圖像堆疊的三維數組的 Dataset
:
print(mnist_ds.batch(100))
<BatchDataset
shapes: (?, 28, 28),
types: tf.uint8>
請注意,該數據集的批次大小是未知的,因爲最後一個批次具有的元素數量會減少。
在 train_input_fn
中,經過批處理之後,Dataset
包含元素的一維向量,其中每個標量之前如下所示:
print(dataset)
<TensorSliceDataset
shapes: (
{
SepalLength: (?,), PetalWidth: (?,),
PetalLength: (?,), SepalWidth: (?,)},
(?,)),
types: (
{
SepalLength: tf.float64, PetalWidth: tf.float64,
PetalLength: tf.float64, SepalWidth: tf.float64},
tf.int64)>
返回
此時,Dataset
包含 (features_dict, labels)
對。這是 train
和 evaluate
方法的預期格式,因此 input_fn
會返回相應的數據集。
使用 predict
方法時,可以/應該忽略 labels
。
讀取 CSV 文件
Dataset
類最常見的實際用例是流式傳輸磁盤上文件中的數據。tf.data
模塊包含各種文件閱讀器。我們來看看如何使用 Dataset
解析 csv 文件中的 Iris 數據集。
對 iris_data.maybe_download
函數的以下調用會根據需要下載數據,並返回所生成文件的路徑名:
import iris_data
train_path, test_path = iris_data.maybe_download()
iris_data.csv_input_fn
函數包含使用 Dataset
解析 csv 文件的備用實現。
我們來了解一下如何構建從本地文件讀取數據且兼容 Estimator 的輸入函數。
構建 Dataset
我們先構建一個 TextLineDataset
對象,實現一次讀取文件中的一行數據。然後,我們調用 skip
方法來跳過文件的第一行,此行包含標題,而非樣本:
ds = tf.data.TextLineDataset(train_path).skip(1)
構建 csv 行解析器
我們先構建一個解析單行的函數。
以下 iris_data.parse_line
函數會使用 tf.decode_csv
函數和一些簡單的 Python 代碼來完成此任務:
爲了生成必要的 (features, label)
對,我們必須解析數據集中的每一行。以下 _parse_line
函數會調用 tf.decode_csv
,以將單行解析爲特徵和標籤兩個部分。由於 Estimator 需要將特徵表示爲字典,因此我們依靠 Python 的內置 dict
和 zip
函數來構建此字典。特徵名稱是該字典的鍵。然後,我們調用字典的 pop
方法以從特徵字典中移除標籤字段:
# Metadata describing the text columns
COLUMNS = ['SepalLength', 'SepalWidth',
'PetalLength', 'PetalWidth',
'label']
FIELD_DEFAULTS = [[0.0], [0.0], [0.0], [0.0], [0]]
def _parse_line(line):
# Decode the line into its fields
fields = tf.decode_csv(line, FIELD_DEFAULTS)
# Pack the result into a dictionary
features = dict(zip(COLUMNS,fields))
# Separate the label from the features
label = features.pop('label')
return features, label
解析行
數據集提供很多用於在通過管道將數據傳送到模型的過程中處理數據的方法。最常用的方法是 map
,它會對 Dataset
的每個元素應用轉換。
map
方法會接受 map_func
參數,此參數描述了應該如何轉換 Dataset
中的每個條目。
因此,爲了在從 csv 文件中流式傳出行時對行進行解析,我們將 _parse_line
函數傳遞給 map
方法:
ds = ds.map(_parse_line)
print(ds)
<MapDataset
shapes: (
{SepalLength: (), PetalWidth: (), ...},
()),
types: (
{SepalLength: tf.float32, PetalWidth: tf.float32, ...},
tf.int32)>
現在,數據集包含 (features, label)
對,而不是簡單的標量字符串。
iris_data.csv_input_fn
函數的剩餘部分與 iris_data.train_input_fn
函數完全相同,後者在基本輸入部分中進行了介紹。
試試看
此函數可用於替換 iris_data.train_input_fn
。可使用此函數饋送 Estimator,如下所示:
train_path, test_path = iris_data.maybe_download()
# All the inputs are numeric
feature_columns = [
tf.feature_column.numeric_column(name)
for name in iris_data.CSV_COLUMN_NAMES[: -1]
]
# Build the estimator
est = tf.estimator.LinearClassifier(feature_columns,
n_classes = 3)# Train the estimator
batch_size = 100
est.train(
steps = 1000,
input_fn = lambda: iris_data.csv_input_fn(train_path, batch_size))
Estimator 要求 input_fn
不接受任何參數。爲了不受此限制約束,我們使用 lambda
來獲取參數並提供所需的接口。
總結
tf.data
模塊提供一系列類和函數,可用於輕鬆從各種來源讀取數據。此外,tf.data
還提供簡單而又強大的方法,用於應用各種標準和自定義轉換。
現在,您已經基本瞭解瞭如何高效地將數據加載到 Estimator 中。接下來,請查看下列文檔:
- 創建自定義 Estimator:展示瞭如何自行構建自定義
Estimator
模型。 - 低階 API 簡介:展示瞭如何使用 TensorFlow 的低階 API 直接嘗試
tf.data.Datasets
。 - 導入數據:詳細介紹了
Datasets
的其他功能。