TFRecord讀寫

一個寫入TFRecord的簡單示例。輸入的 train.txt / test.txt 每一行是以空格分隔的 4×289 個浮點數,依次為 origin、noise、darked、noise_darked 四個向量。


import os
import tensorflow as tf

# tf version: 1.12
# The input text files and the output TFRecord files all live under one dataset directory.
# NOTE: machine-specific absolute path; adjust for your environment.
root_path = '/home/wuyanxue/Data/StandardTestImages/dataset/'

# Plain-text inputs: one sample per line.
train_path, test_path = (
    os.path.join(root_path, fname) for fname in ('train.txt', 'test.txt')
)

# TFRecord outputs, one per split.
out_train_tfrecords, out_test_tfrecords = (
    os.path.join(root_path, fname) for fname in ('train.tfrecords', 'test.tfrecords')
)

# One writer per split (train opened first, then test); the writing code below closes them.
train_writer, test_writer = (
    tf.io.TFRecordWriter(out_path)
    for out_path in (out_train_tfrecords, out_test_tfrecords)
)

def serialize_example(origin, noise, darked, noise_darked):
    """Pack four float vectors into one serialized ``tf.train.Example``.

    Each argument is a sequence of floats stored under the feature key of the
    same name as a ``FloatList``. (The writing loop in this script passes
    289-value slices.)

    Returns:
        The Example serialized to a byte string, ready for ``TFRecordWriter.write``.
    """
    def _float_feature(values):
        # Wrap a flat sequence of floats as a tf.train.Feature.
        return tf.train.Feature(float_list=tf.train.FloatList(value=values))

    # Insertion order matches the original: origin, noise, darked, noise_darked.
    columns = (
        ('origin', origin),
        ('noise', noise),
        ('darked', darked),
        ('noise_darked', noise_darked),
    )
    features = tf.train.Features(
        feature={key: _float_feature(vals) for key, vals in columns}
    )
    return tf.train.Example(features=features).SerializeToString()


# Number of floats in each of the four feature vectors per line; the reader
# parses each feature with FixedLenFeature([289]), so this must agree with it.
_FEATURE_LEN = 289
_NUM_FEATURES = 4  # origin, noise, darked, noise_darked (in that column order)


def _write_split(txt_path, writer):
    """Convert one whitespace-separated text file into TFRecord examples.

    Each non-blank line must hold at least ``4 * _FEATURE_LEN`` floats; the first
    four consecutive chunks of ``_FEATURE_LEN`` values become the ``origin``,
    ``noise``, ``darked`` and ``noise_darked`` features. Any extra trailing values
    are ignored, as before.

    Args:
        txt_path: path of the text file to read.
        writer: an open TFRecord writer that receives one serialized Example per line.

    Raises:
        ValueError: if a field is not a float, or a line has too few values
            (which would otherwise yield records the fixed-length reader cannot parse).
    """
    needed = _NUM_FEATURES * _FEATURE_LEN
    with open(txt_path, 'r') as f:
        # Iterate lazily instead of readlines() so the whole file isn't held in memory.
        for line_no, line in enumerate(f, start=1):
            # split() with no argument tolerates runs of spaces/tabs; split(' ')
            # would produce '' fields that float() rejects.
            fields = line.split()
            if not fields:
                continue  # skip blank lines (e.g. a trailing newline at EOF)
            values = [float(c) for c in fields]
            if len(values) < needed:
                raise ValueError(
                    '%s:%d: expected at least %d values, got %d'
                    % (txt_path, line_no, needed, len(values))
                )
            chunks = [
                values[i * _FEATURE_LEN:(i + 1) * _FEATURE_LEN]
                for i in range(_NUM_FEATURES)
            ]
            writer.write(serialize_example(*chunks))


# Close each writer exactly once, even if conversion fails part-way, so the
# output files are flushed and file handles are not leaked.
try:
    try:
        _write_split(train_path, train_writer)
    finally:
        train_writer.close()
    _write_split(test_path, test_writer)
finally:
    test_writer.close()

讀TFRecord示例

import os
import time
# GPU selection is configured through environment variables. They are set here,
# before TensorFlow is imported, so they are in effect when it enumerates GPUs.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'  # number GPUs by PCI bus ID (same order as nvidia-smi)
os.environ['CUDA_VISIBLE_DEVICES'] = '1'  # expose only GPU 1 to this process
# NOTE(review): `time` and `numpy` are not used in the visible code.
import numpy as np
import tensorflow as tf
# tf version: 1.12

root_path = '/home/wuyanxue/Data/StandardTestImages/dataset/'
# TFRecord file produced by the writing script.
train_path = os.path.join(root_path, 'train.tfrecords')

def parse_func(example_proto):
    """Parse one serialized ``tf.train.Example`` into a (model input, target) pair.

    Args:
        example_proto: scalar string tensor holding one serialized Example, as
            yielded element-wise by ``tf.data.TFRecordDataset``.

    Returns:
        ``(noise_darked, origin)``: two float32 tensors of shape ``[289]``.
    """
    # All four features are required fixed-length float vectors. If a
    # default_value is ever supplied, it must match the declared shape
    # (e.g. [0.] * 289); a scalar such as 0.0 is rejected.
    vector_289 = tf.io.FixedLenFeature([289,], tf.float32)
    feature_desc = {
        key: vector_289
        for key in ('origin', 'noise', 'darked', 'noise_darked')
    }
    # Parses a single record; tf.io.parse_example is the counterpart for a
    # batch of serialized records (i.e. when mapping after .batch()).
    parsed = tf.io.parse_single_example(example_proto, feature_desc)
    model_input = parsed['noise_darked']
    target = parsed['origin']
    return model_input, target

# Build the input pipeline: one (noise_darked, origin) pair per TFRecord entry.
train_ds = tf.data.TFRecordDataset(train_path)
train_ds = train_ds.map(parse_func)
# Dataset.make_one_shot_iterator is the TF 1.12-era API for this.
iterator = train_ds.make_one_shot_iterator()
# Despite the name, no .batch() is applied, so this yields a single example.
batch_train_data_tf = iterator.get_next()
# Use the session as a context manager so it is closed and its resources
# released once the fetch is done (the original never closed it).
with tf.Session() as sess:
    # Fetch the first element as a smoke test of the parsing pipeline.
    cc = sess.run(batch_train_data_tf)
print(type(cc))

實驗表明,與直接讀取原始 .txt 文件相比,改用 TFRecord 後訓練速度大幅提升。主要原因可能是:數據已預先解析並序列化為二進制浮點數,省去了每輪訓練重複解析文本的開銷。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章