tensorflow中的dataset

先记录一下读tfrecord的方式

def parser(record, shape=[224, 224, 1]):
    features = tf.parse_single_example(record, features={
        'label': tf.FixedLenFeature([], tf.int64),
        'img': tf.FixedLenFeature([], tf.string), })
    img = tf.decode_raw(features["img"], tf.uint8)  # 注意在这里只能是tf.uint8,tf.float32会报错
    img = tf.reshape(img, shape)
    img = tf.div(tf.to_float(img), 255.0)
    label = tf.cast(features["label"], tf.int64)
    return img, label
if __name__=="__main__":
	train_data = tf.data.TFRecordDataset(train_records)
    train_data = train_data.map(parser)
    train_data = train_data.repeat()
    train_data = train_data.batch(BATCH_SIZE)
    train_data = train_data.shuffle(BATCH_SIZE)
    train_iter = train_data.make_initializable_iterator()
    next_train = train_iter.get_next()

以上是我原本写的代码,但是经过使用发现了很多问题。
虽然调用的shuffle但是数据并未打乱,然后在这篇博客的讲解中,我注意到
对于这些方法的调用顺序,是会对数据链造成影响的。

比如上述代码,repeat->batch->shuffle
这样就是,先repeat,repeat的参数等同于epoch的值,可以缺省
然后提取一个batch,最后再设置shuffle的buffersize

这样相当于没有进行shuffle操作,数据已经取出来了,再建立buffer,打乱其中的数据,也不会影响到已取出的。

遂将代码改成

if __name__=="__main__"
    train_dataset = tf.data.TFRecordDataset(train_files)
    train_dataset = train_dataset.map(parser)
    train_dataset = train_dataset.shuffle(buffer_size=200)
    train_dataset = train_dataset.batch(100)
    train_dataset = train_dataset.repeat()
    train_iter = train_dataset.make_initializable_iterator()
    next_train = train_iter.get_next()

先设置buffer_size=200,打乱其中数据,再提取一个batch=100的数据,最后执行repeat。
这样一来数据的顺序就被打乱了。
做等训练结果。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章