一個簡單的寫TFRecord示例
import os
import tensorflow as tf
# tf version: 1.12
# Dataset directory: the text inputs and the TFRecord outputs all live here.
root_path = '/home/wuyanxue/Data/StandardTestImages/dataset/'

# Plain-text input file for each split.
train_path, test_path = [
    os.path.join(root_path, txt_name)
    for txt_name in ('train.txt', 'test.txt')
]

# TFRecord output file for each split.
out_train_tfrecords, out_test_tfrecords = [
    os.path.join(root_path, record_name)
    for record_name in ('train.tfrecords', 'test.tfrecords')
]

# One writer per output file; the train writer is created first, then test.
train_writer, test_writer = [
    tf.io.TFRecordWriter(record_path)
    for record_path in (out_train_tfrecords, out_test_tfrecords)
]
def serialize_example(origin, noise, darked, noise_darked):
    """Pack four float sequences into one serialized ``tf.train.Example``.

    Args:
        origin: Floats stored under the ``'origin'`` feature key.
        noise: Floats stored under the ``'noise'`` feature key.
        darked: Floats stored under the ``'darked'`` feature key.
        noise_darked: Floats stored under the ``'noise_darked'`` feature key.

    Returns:
        The result of ``SerializeToString()`` on the assembled Example.
    """
    named_values = (
        ('origin', origin),
        ('noise', noise),
        ('darked', darked),
        ('noise_darked', noise_darked),
    )

    # Every input is stored the same way: as a FloatList-backed Feature.
    feature_map = {}
    for key, values in named_values:
        float_list = tf.train.FloatList(value=values)
        feature_map[key] = tf.train.Feature(float_list=float_list)

    features = tf.train.Features(feature=feature_map)
    return tf.train.Example(features=features).SerializeToString()
def _write_split(txt_path, writer, feature_len=289):
    """Convert one text file of float rows into TFRecord examples.

    Each non-blank line is split on whitespace and parsed as floats. The
    first ``4 * feature_len`` values are cut into four consecutive chunks of
    ``feature_len`` values, in the order origin, noise, darked, noise_darked,
    and written through ``writer`` as one serialized example. Values beyond
    the first ``4 * feature_len`` are ignored.

    Args:
        txt_path: Path of the text file to read.
        writer: An open TFRecord writer. It is always closed before this
            function returns, including when an error is raised.
        feature_len: Number of floats per feature. The default of 289 is the
            fixed length the reader in this file declares for every feature.

    Raises:
        ValueError: If a token is not a valid float, or if a non-blank line
            holds fewer than ``4 * feature_len`` values. Failing here, with
            the line number, avoids writing a record whose features are
            shorter than the fixed length the reader expects.
    """
    required = 4 * feature_len
    try:
        with open(txt_path, 'r') as f:
            # Iterate over the file object so lines are streamed rather than
            # loaded into memory all at once.
            for line_no, line in enumerate(f, 1):
                # split() without an argument tolerates repeated spaces, tabs
                # and the trailing newline; split(' ') would yield empty
                # tokens that float() rejects.
                values = [float(token) for token in line.split()]
                if not values:
                    # Blank line (e.g. a trailing newline at end of file).
                    continue
                if len(values) < required:
                    raise ValueError(
                        '%s:%d: expected at least %d values, got %d'
                        % (txt_path, line_no, required, len(values)))
                origin, noise, darked, noise_darked = (
                    values[i * feature_len:(i + 1) * feature_len]
                    for i in range(4))
                writer.write(
                    serialize_example(origin, noise, darked, noise_darked))
    finally:
        # Close the writer even if reading or parsing failed part-way.
        writer.close()


_write_split(train_path, train_writer)
_write_split(test_path, test_writer)
讀TFRecord示例
import os
import time
# GPU selection for this process: order devices by PCI bus ID and expose only
# device index 1.
# NOTE(review): these assignments sit before `import tensorflow`, presumably
# on purpose so they are in place when the CUDA runtime is initialized --
# keep this ordering; confirm before moving the imports.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import numpy as np
import tensorflow as tf
# tf version: 1.12
# Dataset directory and the training TFRecord file that is read back below.
root_path = '/home/wuyanxue/Data/StandardTestImages/dataset/'
train_path = os.path.join(root_path, 'train.tfrecords')
def parse_func(example_proto):
    """Decode one serialized example into a ``(noise_darked, origin)`` pair.

    Args:
        example_proto: A single serialized example, as yielded by the
            TFRecord dataset this function is mapped over.

    Returns:
        A tuple ``(noise_darked, origin)`` of the two parsed features, each
        declared as a float32 feature of fixed length 289.
    """
    # Author's notes on default_value for these features:
    #   * a scalar default (default_value=0.0) cannot be passed for a
    #     [289]-shaped feature;
    #   * a default has to match the feature shape, e.g.
    #     default_value=[0.] * 289.
    # No default is supplied here.
    feature_desc = {
        key: tf.io.FixedLenFeature([289,], tf.float32)
        for key in ('origin', 'noise', 'darked', 'noise_darked')
    }

    # One record at a time; the author left tf.io.parse_example as a
    # commented-out alternative to this call.
    parsed = tf.io.parse_single_example(example_proto, feature_desc)

    model_input = parsed['noise_darked']
    target = parsed['origin']
    return model_input, target
# Input pipeline: read serialized records from the training TFRecord file and
# decode each one with parse_func into a (noise_darked, origin) pair.
train_ds = tf.data.TFRecordDataset(train_path)
train_ds = train_ds.map(parse_func)

# One-shot iterator (TF 1.x graph mode); get_next() returns the tensors for
# the next dataset element. No batching is applied, so despite its name
# `batch_train_data_tf` holds a single example.
iterator = train_ds.make_one_shot_iterator()
batch_train_data_tf = iterator.get_next()

# Use the session as a context manager so its resources are released even if
# run() raises; previously the session was never closed.
with tf.Session() as sess:
    cc = sess.run(batch_train_data_tf)
print(type(cc))
實驗表明，與直接讀取原始的 .txt 文件相比，使用 TFRecord 後訓練速度有質的提升。