import tensorflow as tf
def _int64_feature(value):
    """Wrap an integer, or a list/tuple of integers, in a ``tf.train.Feature``.

    Args:
        value: A single integer, or a list/tuple of integers.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` holds the value(s).
    """
    # Int64List is a repeated field, so a scalar is promoted to a one-element
    # list. Tuples are passed through as sequences; the old
    # `type(value) != list` test wrapped them in a list, which Int64List
    # then rejected.
    if not isinstance(value, (list, tuple)):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def _bytes_feature(value):
    """Wrap a bytes/text value, or a list/tuple of them, in a ``tf.train.Feature``.

    Args:
        value: A single bytes or text string, or a list/tuple of them.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds the value(s).

    Raises:
        TypeError: If an element is neither bytes nor text.
    """
    if not isinstance(value, (list, tuple)):
        value = [value]
    # BytesList only accepts bytes under Python 3; tf.compat.as_bytes
    # UTF-8-encodes text and passes bytes through unchanged.
    return tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[tf.compat.as_bytes(v) for v in value]))
# Write three tf.train.Example records to ./tfrecord.test. Each record has a
# single int64 feature "video_id" holding a variable-length list of ids.
# Using the writer as a context manager guarantees it is closed (and the file
# flushed) even if building or writing an example raises part-way through.
with tf.python_io.TFRecordWriter("./tfrecord.test") as writer:
    for video_ids in ([3, 2], [3, 2, 1, 2, 3], [3, 2, 9]):
        feature_config = {'video_id': _int64_feature(video_ids)}
        example = tf.train.Example(
            features=tf.train.Features(feature=feature_config))
        # Serialize the protobuf to a string and append it as one record.
        writer.write(example.SerializeToString())
# Read the records back through the TF1 queue-runner input pipeline, parse the
# "video_id" feature through feature columns, and print its dense
# representation for each of three batches.
with tf.Session() as sess:
    # Filename queue that yields the file once. num_epochs=1 creates a local
    # epoch-counter variable, hence local_variables_initializer below.
    filename_queue = tf.train.string_input_producer(
        ["./tfrecord.test"], num_epochs=1)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    # Treat each video_id value as an integer id with num_buckets=10 and
    # default_value=0.
    video_id = tf.feature_column.categorical_column_with_identity(
        key='video_id', num_buckets=10, default_value=0)
    # input_layer only accepts dense columns, so the categorical column is
    # wrapped in an indicator column rather than passed directly.
    video_id2 = tf.feature_column.indicator_column(video_id)
    columns = [video_id2]

    video_out = tf.train.shuffle_batch(
        [serialized_example], batch_size=1, capacity=30, num_threads=1,
        min_after_dequeue=3)
    features = tf.parse_example(
        video_out, features=tf.feature_column.make_parse_example_spec(columns))
    # Dense tensor produced from the indicator column.
    dense_video_id = tf.feature_column.input_layer(features, columns)

    # Initialize global and local variables (the latter includes the epoch
    # counter created by string_input_producer).
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)

    # Start the QueueRunner threads that feed the filename and batch queues.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for _ in range(3):
            x = sess.run([dense_video_id])
            print(x)
    except tf.errors.OutOfRangeError:
        # Raised when the single epoch is exhausted and the queues are drained.
        print('Input exhausted before 3 batches were read.')
    finally:
        # Always signal the feeder threads to stop, even if sess.run raised.
        coord.request_stop()
    # Wait for the threads to finish; the enclosing `with` closes the session.
    coord.join(threads)
tensorflow tfrecord save and read demo
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.