tensorflow中的dataset

先記錄一下讀tfrecord的方式

def parser(record, shape=[224, 224, 1]):
    features = tf.parse_single_example(record, features={
        'label': tf.FixedLenFeature([], tf.int64),
        'img': tf.FixedLenFeature([], tf.string), })
    img = tf.decode_raw(features["img"], tf.uint8)  # 注意在這裏只能是tf.uint8,tf.float32會報錯
    img = tf.reshape(img, shape)
    img = tf.div(tf.to_float(img), 255.0)
    label = tf.cast(features["label"], tf.int64)
    return img, label
if __name__=="__main__":
	train_data = tf.data.TFRecordDataset(train_records)
    train_data = train_data.map(parser)
    train_data = train_data.repeat()
    train_data = train_data.batch(BATCH_SIZE)
    train_data = train_data.shuffle(BATCH_SIZE)
    train_iter = train_data.make_initializable_iterator()
    next_train = train_iter.get_next()

以上是我原本寫的代碼,但是經過使用發現了很多問題。
雖然調用的shuffle但是數據並未打亂,然後在這篇博客的講解中,我注意到
對於這些方法的調用順序,是會對數據鏈造成影響的。

比如上述代碼,repeat->batch->shuffle
這樣就是,先repeat,repeat的參數等同於epoch的值,可以缺省
然後提取一個batch,最後再設置shuffle的buffersize

這樣相當於沒有進行shuffle操作,數據已經取出來了,再建立buffer,打亂其中的數據,也不會影響到已取出的。

遂將代碼改成

if __name__=="__main__"
    train_dataset = tf.data.TFRecordDataset(train_files)
    train_dataset = train_dataset.map(parser)
    train_dataset = train_dataset.shuffle(buffer_size=200)
    train_dataset = train_dataset.batch(100)
    train_dataset = train_dataset.repeat()
    train_iter = train_dataset.make_initializable_iterator()
    next_train = train_iter.get_next()

先設置buffer_size=200,打亂其中數據,再提取一個batch=100的數據,最後執行repeat。
這樣一來數據的順序就被打亂了。
做等訓練結果。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章