Tensorflow中的數據對象Dataset - shuffle()、repeat()、batch() 等用法

基礎概念

在tensorflow的官方文檔是這樣介紹Dataset數據對象的:

Dataset可以用來表示輸入管道元素集合(張量的嵌套結構)和“邏輯計劃“對這些元素的轉換操作。在Dataset中元素可以是向量,元組或字典等形式。
另外,Dataset需要配合另外一個類Iterator進行使用,Iterator對象是一個迭代器,可以對Dataset中的元素進行迭代提取。

看個簡單的示例:

#創建一個Dataset對象
dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])  
# 創建一個迭代器
iterator = dataset.make_one_shot_iterator()
# get_next()函數可以幫助我們從迭代器中獲取元素
element = iterator.get_next()
# 遍歷迭代器,獲取所有元素
with tf.Session() as sess:   
    for i in range(9):
        print(sess.run(element))

以上打印結果爲:1 2 3 4 5 6 7 8 9

Dataset方法

 

1.  from_tensor_slices

from_tensor_slices 用於創建dataset,其元素是給定張量的切片的元素。

函數形式:from_tensor_slices(tensors)

參數tensors:張量的嵌套結構,每個都在第0維中具有相同的大小。

具體例子

#創建切片形式的dataset
dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])
#創建一個迭代器
iterator = dataset.make_one_shot_iterator()
#get_next()函數可以幫助我們從迭代器中獲取元素
element = iterator.get_next()
#遍歷迭代器,獲取所有元素
with tf.Session() as sess:   
    for i in range(3):
       print(sess.run(element))

2.  from_tensors

創建一個Dataset包含給定張量的單個元素。

函數形式:from_tensors(tensors)

參數tensors:張量的嵌套結構。

具體例子

dataset = tf.data.Dataset.from_tensors([1,2,3,4,5,6,7,8,9])

iterator = concat_dataset.make_one_shot_iterator()

element = iterator.get_next()

with tf.Session() as sess:   
    for i in range(1):
       print(sess.run(element))
以上代碼運行結果:[1,2,3,4,5,6,7,8,9]

 結論: 即from_tensors是將tensors作爲一個整體進行操縱,而from_tensor_slices可以操縱tensors裏面的元素。

 

3.  from_generator

創建Dataset由其生成元素的元素generator。

函數形式:from_generator(generator,output_types,output_shapes=None,args=None)

參數generator:一個可調用對象,它返回支持該iter()協議的對象 。如果args未指定,generator則不得參數; 否則它必須採取與有值一樣多的參數args。
參數output_types:tf.DType對應於由元素生成的元素的每個組件的對象的嵌套結構generator。
參數output_shapes:tf.TensorShape 對應於由元素生成的元素的每個組件的對象 的嵌套結構generator
參數args:tf.Tensor將被計算並將generator作爲NumPy數組參數傳遞的對象元組。

具體例子

#定義一個生成器
def data_generator():
    dataset = np.array(range(9))    
    for i in dataset:        
        yield i
#接收生成器,並生產dataset數據結構        
dataset = tf.data.Dataset.from_generator(data_generator, (tf.int32))

iterator = concat_dataset.make_one_shot_iterator()

element = iterator.get_next()

with tf.Session() as sess:   
    for i in range(3):
       print(sess.run(element))
以上代碼運行結果:0 1 2

4.  batch

batch可以將數據集的連續元素合成批次。

函數形式:batch(batch_size,drop_remainder=False)

參數batch_size:表示要在單個批次中合併的此數據集的連續元素個數。
參數drop_remainder:表示在少於batch_size元素的情況下是否應刪除最後一批 ; 默認是不刪除。

具體例子:

#創建一個Dataset對象
dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])
'''合成批次'''
dataset=dataset.batch(3)
#創建一個迭代器
iterator = dataset.make_one_shot_iterator()
#get_next()函數可以幫助我們從迭代器中獲取元素
element = iterator.get_next()

#遍歷迭代器,獲取所有元素
with tf.Session() as sess:   
    for i in range(9):
       print(sess.run(element))

 

以上代碼運行結果爲:
[1 2 3]
[4 5 6]
[7 8 9]

即把目標對象合成3個批次,返回的對象是傳入Dataset對象。

5.  concatenate

concatenate可以將兩個Dataset對象進行合併或連接.

函數形式:concatenate(dataset)

參數dataset:表示需要傳入的dataset對象。

具體例子:

#創建dataset對象
dataset_a=tf.data.Dataset.from_tensor_slices([1,2,3])
dataset_b=tf.data.Dataset.from_tensor_slices([4,5,6])
#合併dataset
concat_dataset=dataset_a.concatenate(dataset_b)
iterator = concat_dataset.make_one_shot_iterator()
element = iterator.get_next()

with tf.Session() as sess:
   for i in range(6):
       print(sess.run(element))
以上代碼運行結果:1 2 3 4 5 6

6.  filter

filter可以對傳入的dataset數據進行條件過濾.

函數形式:filter(predicate)

參數predicate:條件過濾函數

具體例子

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])
#對dataset內的數據進行條件過濾
dataset=dataset.filter(lambda x:x>3)

iterator = dataset.make_one_shot_iterator()

element = iterator.get_next()

with tf.Session() as sess:    
    for i in range(6):
       print(sess.run(element))

以上代碼運行結果:4 5 6 7 8 9

 

7.  map

map可以將map_func函數映射到數據集

函數形式:flat_map(map_func,num_parallel_calls=None)

參數map_func:映射函數
參數num_parallel_calls:表示要並行處理的數字元素。如果未指定,將按順序處理元素。

具體例子

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])
#進行map操作
dataset=dataset.map(lambda x:x+1)

iterator = dataset.make_one_shot_iterator()

element = iterator.get_next()

with tf.Session() as sess:   
    for i in range(6):
       print(sess.run(element))

以上代碼運行結果:2 3 4 5 6 7

 

8.  flat_map

flat_map可以將map_func函數映射到數據集(與map不同的是flat_map傳入的數據必須是一個dataset)。

函數形式:flat_map(map_func)

參數map_func:映射函數

具體例子

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])
#進行flat_map操作
dataset=dataset.flat_map(lambda x:tf.data.Dataset.from_tensor_slices(x+[1]))
iterator = dataset.make_one_shot_iterator()
element = iterator.get_next()

with tf.Session() as sess:
   for i in range(6):
       print(sess.run(element))

以上代碼運行結果:2 3 4 5 6 7

 

9.  make_one_shot_iterator

創建Iterator用於枚舉此數據集的元素。(可自動初始化)

函數形式:make_one_shot_iterator()

具體例子

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])

iterator = dataset.make_one_shot_iterator()

element = iterator.get_next()

with tf.Session() as sess:   
    for i in range(6):
       print(sess.run(element))

10.  make_initializable_iterator

創建Iterator用於枚舉此數據集的元素。(使用此函數前需先進行迭代器的初始化操作)

函數形式:make_initializable_iterator(shared_name=None)

參數shared_name:(可選)如果非空,則返回的迭代器將在給定名稱下共享同一設備的多個會話(例如,使用遠程服務器時)

具體例子

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])

iterator = dataset.make_initializable_iterator()

element = iterator.get_next()

with tf.Session() as sess:   #對迭代器進行初始化操作
    sess.run(iterator.initializer)   
    for i in range(5):
        print(sess.run(element))

 

11.  padded_batch

將數據集的連續元素組合到填充批次中,此轉換將輸入數據集的多個連續元素組合爲單個元素。

函數形式:padded_batch(batch_size,padded_shapes,padding_values=None,drop_remainder=False)

參數batch_size:表示要在單個批次中合併的此數據集的連續元素數。
參數padded_shapes:嵌套結構tf.TensorShape或 tf.int64類似矢量張量的對象,表示在批處理之前應填充每個輸入元素的相應組件的形狀。任何未知的尺寸(例如,tf.Dimension(None)在一個tf.TensorShape或-1類似張量的物體中)將被填充到每個批次中該尺寸的最大尺寸。
參數padding_values:(可選)標量形狀的嵌套結構 tf.Tensor,表示用於各個組件的填充值。默認值0用於數字類型,空字符串用於字符串類型。
參數drop_remainder:(可選)一個tf.bool標量tf.Tensor,表示在少於batch_size元素的情況下是否應刪除最後一批 ; 默認行爲是不刪除較小的批處理。

具體例子

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])

dataset=dataset.padded_batch(2,padded_shapes=[])

iterator = dataset.make_one_shot_iterator()

element = iterator.get_next()

with tf.Session() as sess:   
    for i in range(6):
       print(sess.run(element))

以上代碼運行結果:
[1 2]
[3 4]

 

12.  repeat  √

重複此數據集count次數

函數形式:repeat(count=None)

參數count:(可選)表示數據集應重複的次數。默認行爲(如果count是None或-1)是無限期重複的數據集。

具體例子

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])#無限次重複dataset數據集dataset=dataset.repeat()

iterator = dataset.make_one_shot_iterator()

element = iterator.get_next()

with tf.Session() as sess:   
    for i in range(30,35):
       print(sess.run(element))

以上代碼運行結果: 1 2 3 4 5

13.  shard

將Dataset分割成num_shards個子數據集。這個函數在分佈式訓練中非常有用,它允許每個設備讀取唯一子集。

函數形式:shard( num_shards,index)

參數num_shards:表示並行運行的分片數。
參數index:表示工人索引。

 

14.  shuffle

隨機混洗數據集的元素。

函數形式:shuffle(buffer_size,seed=None,reshuffle_each_iteration=None)

參數buffer_size:表示新數據集將從中採樣的數據集中的元素數。
參數seed:(可選)表示將用於創建分佈的隨機種子。
參數reshuffle_each_iteration:(可選)一個布爾值,如果爲true,則表示每次迭代時都應對數據集進行僞隨機重組。(默認爲True。)

具體例子

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])
#隨機混洗數據
dataset=dataset.shuffle(3)
iterator = dataset.make_one_shot_iterator()
element = iterator.get_next()

with tf.Session() as sess:   
    for i in range(30,35):
       print(sess.run(element))

以上代碼運行結果:3 2 4

 

15.  skip

生成一個跳過count元素的數據集。

函數形式:skip(count)

參數count:表示應跳過以形成新數據集的此數據集的元素數。如果count大於此數據集的大小,則新數據集將不包含任何元素。如果count 爲-1,則跳過整個數據集。

具體例子

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])
#跳過前5個元素
dataset=dataset.skip(5)
iterator = dataset.make_one_shot_iterator()
element = iterator.get_next()

with tf.Session() as sess:   
    for i in range(30,35):
       print(sess.run(element))

以上代碼運行結果: 6 7 8

 

16.  take

提取前count個元素形成性數據集

函數形式:take(count)

參數count:表示應該用於形成新數據集的此數據集的元素數。如果count爲-1,或者count大於此數據集的大小,則新數據集將包含此數據集的所有元素。

具體例子

dataset = tf.data.Dataset.from_tensor_slices([1,2,2,3,4,5,6,7,8,9])
#提取前5個元素形成新數據
dataset=dataset.take(5)
iterator = dataset.make_one_shot_iterator()
element = iterator.get_next()

with tf.Session() as sess:   
    for i in range(30,35):
       print(sess.run(element))

以上代碼運行結果: 1 2 2

 

17.  zip

將給定數據集壓縮在一起

函數形式:zip(datasets)

參數datesets:數據集的嵌套結構。

具體例子

dataset_a=tf.data.Dataset.from_tensor_slices([1,2,3])
dataset_b=tf.data.Dataset.from_tensor_slices([2,6,8])
zip_dataset=tf.data.Dataset.zip((dataset_a,dataset_b))
iterator = dataset.make_one_shot_iterator()
element = iterator.get_next()

with tf.Session() as sess:
   for i in range(30,35):
       print(sess.run(element))

以上代碼運行結果:
(1, 2)
(2, 6)
(3, 8)

到這裏Dataset中大部分方法 都在這裏做了初步的解釋,當然這些方法的配合使用才能夠在建模過程中發揮大作用。

 

傳送門:

1. TensorFlow.org教程筆記(二) DataSets 快速入門

2. 使用tf.data.Dataset.from_tensor_slices五步加載數據集

3. tf.data.Dataset.from_tensor_slices中的shuffle()、repeat()、batch()用法

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章