先記錄一下讀tfrecord的方式
def parser(record, shape=[224, 224, 1]):
features = tf.parse_single_example(record, features={
'label': tf.FixedLenFeature([], tf.int64),
'img': tf.FixedLenFeature([], tf.string), })
img = tf.decode_raw(features["img"], tf.uint8) # 注意在這裏只能是tf.uint8,tf.float32會報錯
img = tf.reshape(img, shape)
img = tf.div(tf.to_float(img), 255.0)
label = tf.cast(features["label"], tf.int64)
return img, label
if __name__=="__main__":
train_data = tf.data.TFRecordDataset(train_records)
train_data = train_data.map(parser)
train_data = train_data.repeat()
train_data = train_data.batch(BATCH_SIZE)
train_data = train_data.shuffle(BATCH_SIZE)
train_iter = train_data.make_initializable_iterator()
next_train = train_iter.get_next()
以上是我原本寫的代碼,但是經過使用發現了很多問題。
雖然調用的shuffle但是數據並未打亂,然後在這篇博客的講解中,我注意到
對於這些方法的調用順序,是會對數據鏈造成影響的。
比如上述代碼,repeat->batch->shuffle
這樣就是,先repeat,repeat的參數等同於epoch的值,可以缺省
然後提取一個batch,最後再設置shuffle的buffersize
這樣相當於沒有進行shuffle操作,數據已經取出來了,再建立buffer,打亂其中的數據,也不會影響到已取出的。
遂將代碼改成
if __name__=="__main__"
train_dataset = tf.data.TFRecordDataset(train_files)
train_dataset = train_dataset.map(parser)
train_dataset = train_dataset.shuffle(buffer_size=200)
train_dataset = train_dataset.batch(100)
train_dataset = train_dataset.repeat()
train_iter = train_dataset.make_initializable_iterator()
next_train = train_iter.get_next()
先設置buffer_size=200,打亂其中數據,再提取一個batch=100的數據,最後執行repeat。
這樣一來數據的順序就被打亂了。
做等訓練結果。