tf.data.Dataset: repeat, batch, and shuffle

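Conclusion

| Operation | Result |
| --- | --- |
| Shuffle first (before Repeat) | Each repetition is shuffled independently, so every repeat comes out in a different order |
| Shuffle last (after Batch) | Only whole batches are reordered; the elements inside each batch keep their original order in every repeat |
| Repeat, then Batch | All repeated data is concatenated and then batched, so a batch can span two repeats; batch_nums = ceil(data_nums × repeat_nums / batch_size) |
| Batch, then Repeat | The data is first split into batches of batch_size, and that sequence of batches is then repeated repeat_nums times |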
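The batch counts in the table can be checked with a little arithmetic. The sketch below is plain Python, not TensorFlow; the function names are made up for illustration, and the second formula is my own reading of the "Batch, then Repeat" row rather than something stated in the table.

```python
import math

def batches_repeat_then_batch(data_nums, repeat_nums, batch_size):
    # All repeats are concatenated first, so only the very last batch can be short.
    return math.ceil(data_nums * repeat_nums / batch_size)

def batches_batch_then_repeat(data_nums, repeat_nums, batch_size):
    # Assumption: each pass is batched on its own, so every pass may end in a short batch.
    return math.ceil(data_nums / batch_size) * repeat_nums

# Same numbers as the experiments below: 4 elements, 2 repeats, batch size 3.
print(batches_repeat_then_batch(4, 2, 3))  # 3
print(batches_batch_then_repeat(4, 2, 3))  # 4
```

With 4 elements, 2 repeats and a batch size of 3, this gives 3 batches for Repeat then Batch and 4 batches for Batch then Repeat, which agrees with the number of printed rows in experiments 4.1 and 4.2.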

1. Imports

import matplotlib.pyplot as plt
import tensorflow as tf
tf.__version__

2. Loading the data

titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic_lines = tf.data.TextLineDataset(titanic_file)
Downloading data from https://storage.googleapis.com/tf-datasets/titanic/train.csv
32768/30874 [===============================] - 0s 0us/step

3. Counter

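lines = titanic_lines
# Infinite dataset that yields 0, 1, 2, ...; zipped with the lines below so that
# each line carries its original index and the effect of shuffling is visible.
counter = tf.data.experimental.Counter()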
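The counter is infinite, yet the experiments below terminate, because zipping stops at the shorter input. As a rough plain-Python analogy (not the TensorFlow API; the `lines` list here is a hypothetical stand-in for the CSV lines), the same idea looks like this:

```python
import itertools

# Hypothetical stand-in for the text lines read from the CSV file.
lines = ["header", "row 1", "row 2", "row 3", "row 4"]

counter = itertools.count()          # infinite: 0, 1, 2, ...
lines4 = itertools.islice(lines, 4)  # analogous to lines.take(4)

# zip stops as soon as the shorter input is exhausted,
# so the infinite counter is cut off after 4 pairs.
pairs = list(zip(counter, lines4))
print(pairs)
```

Each line ends up tagged with its position, which is what makes the shuffled order readable in the experiment outputs.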

4. Experiments

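4.1 Shuffle -> Repeat -> Batch

lines4 = lines.take(4)
dataset = tf.data.Dataset.zip((counter, lines4))
# Shuffle the 4 elements, repeat the shuffled stream twice, then cut it into batches of 3.
shuffled = dataset.shuffle(buffer_size=3).repeat(2).batch(3)

# Print the counter values (original line indices) contained in each batch.
for n, (i, _) in enumerate(shuffled):
    print(n, ':  ', i.numpy())
0 :   [1 3 2]
1 :   [0 2 0]
2 :   [1 3]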
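To see why the output looks this way, here is a small plain-Python model of the pipeline. It assumes the documented shuffle-buffer behaviour (fill a buffer of buffer_size elements, emit a random one, refill from the input) and is only a sketch, not TensorFlow's implementation; the helper names `shuffle` and `batch` are my own.

```python
import random

def shuffle(source, buffer_size, rng):
    """Model of a shuffle buffer: emit a random buffered element, refill from source."""
    buffer = []
    for item in source:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        idx = rng.randrange(buffer_size)
        yield buffer[idx]        # emit a random buffered element
        buffer[idx] = item       # the new item takes its slot
    rng.shuffle(buffer)          # input exhausted: drain the rest in random order
    yield from buffer

def batch(source, batch_size):
    """Group consecutive elements; the final batch may be shorter."""
    chunk = []
    for item in source:
        chunk.append(item)
        if len(chunk) == batch_size:
            yield chunk
            chunk = []
    if chunk:
        yield chunk

rng = random.Random(0)  # fixed seed so the sketch is repeatable
# shuffle -> repeat -> batch: each of the 2 passes is shuffled anew,
# then the joined stream is cut into batches of 3.
stream = [x for _ in range(2) for x in shuffle(range(4), 3, rng)]
batches = list(batch(stream, 3))
print(batches)
```

Two things follow from the model: the middle batch mixes elements from both passes, and with 4 elements and a buffer of 3 the last element (3) can never come out first in a pass, because it is not yet in the buffer when the first element is drawn.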
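
4.2 Shuffle -> Batch -> Repeat

lines4 = lines.take(4)
dataset = tf.data.Dataset.zip((counter, lines4))
# Shuffle the 4 elements, batch them (3 + 1), then repeat the whole pass twice.
shuffled = dataset.shuffle(buffer_size=3).batch(3).repeat(2)

# Print the counter values (original line indices) contained in each batch.
for n, (i, _) in enumerate(shuffled):
    print(n, ':  ', i.numpy())
0 :   [2 0 1]
1 :   [3]
2 :   [0 2 1]
3 :   [3]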
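
4.3 Batch -> Repeat -> Shuffle

lines4 = lines.take(4)
dataset = tf.data.Dataset.zip((counter, lines4))
# Batch first (3 + 1), repeat the batches twice, then shuffle whole batches.
shuffled = dataset.batch(3).repeat(2).shuffle(buffer_size=3)

# Print the counter values (original line indices) contained in each batch.
for n, (i, _) in enumerate(shuffled):
    print(n, ':  ', i.numpy())
0 :   [0 1 2]
1 :   [3]
2 :   [0 1 2]
3 :   [3]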
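When shuffle comes after batch, the things being shuffled are the batches themselves. A minimal plain-Python illustration follows; note the simplification that it shuffles the whole list at once instead of modelling a size-3 buffer, so it shows only what can and cannot change, not the exact distribution of orders.

```python
import random

# The batches produced by batch(3).repeat(2) on the elements 0..3.
batches = [[0, 1, 2], [3], [0, 1, 2], [3]]

rng = random.Random(0)  # fixed seed so the sketch is repeatable
shuffled = batches[:]
rng.shuffle(shuffled)   # simplification: full shuffle, not a size-3 shuffle buffer

# The order of the batches may change, but each batch is moved as a unit.
print(shuffled)
```

However the batches are reordered, every batch still holds its elements in the original order, which is why [0 1 2] and [3] always appear intact in the output above.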
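
4.4 Repeat -> Batch -> Shuffle

lines4 = lines.take(4)
dataset = tf.data.Dataset.zip((counter, lines4))
# Repeat the 4 elements twice, batch the joined stream (3 + 3 + 2), then shuffle whole batches.
shuffled = dataset.repeat(2).batch(3).shuffle(buffer_size=3)

# Print the counter values (original line indices) contained in each batch.
for n, (i, _) in enumerate(shuffled):
    print(n, ':  ', i.numpy())
0 :   [0 1 2]
1 :   [3 0 1]
2 :   [2 3]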
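Leaving the shuffle step aside, the difference between the last two experiments comes down to whether repeat or batch runs first. The following plain-Python sketch models just that; the helpers `batch` and `repeat` are illustrative stand-ins of my own, not the tf.data methods.

```python
from itertools import chain, islice

def batch(source, batch_size):
    """Group consecutive elements; the final batch may be shorter."""
    it = iter(source)
    while chunk := list(islice(it, batch_size)):
        yield chunk

def repeat(make_source, count):
    """Replay a freshly created source `count` times, back to back."""
    return chain.from_iterable(make_source() for _ in range(count))

data = range(4)

# Repeat first: the two passes are joined, so a batch can straddle the boundary.
repeat_then_batch = list(batch(repeat(lambda: data, 2), 3))
# Batch first: each pass is batched on its own and ends with a short batch.
batch_then_repeat = list(repeat(lambda: batch(data, 3), 2))

print(repeat_then_batch)  # [[0, 1, 2], [3, 0, 1], [2, 3]]
print(batch_then_repeat)  # [[0, 1, 2], [3], [0, 1, 2], [3]]
```

The two printed lists contain the same batches as the outputs of experiments 4.4 and 4.3 respectively.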