TensorFlow TFRecord save-and-read demo


import tensorflow as tf

def _int64_feature(value):
    """Wrap an int (or a sequence of ints) in a ``tf.train.Feature``.

    Args:
        value: a single integer, or a list/tuple of integers.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` holds the value(s).
    """
    # Int64List expects a sequence, so a lone scalar is wrapped in a list.
    # isinstance (rather than an exact type check) also accepts tuples and
    # list subclasses instead of double-wrapping them.
    if not isinstance(value, (list, tuple)):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def _bytes_feature(value):
    """Wrap a bytes value (or a sequence of them) in a ``tf.train.Feature``.

    Args:
        value: a single bytes object, or a list/tuple of bytes objects.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds the value(s).
    """
    # BytesList expects a sequence, so a lone value is wrapped in a list.
    if not isinstance(value, (list, tuple)):
        value = [value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

feature_config = {}

# Each record carries one variable-length int64 feature, "video_id".
video_id_records = [[3, 2], [3, 2, 1, 2, 3], [3, 2, 9]]

# Open the TFRecords file; the context manager closes it even if a write fails.
with tf.python_io.TFRecordWriter("./tfrecord.test") as writer:
    for video_ids in video_id_records:
        feature_config['video_id'] = _int64_feature(video_ids)
        example = tf.train.Example(features=tf.train.Features(feature=feature_config))
        # Serialize to string and write on the file
        writer.write(example.SerializeToString())



# Read the records back through the TF1 queue-based input pipeline and turn
# the "video_id" feature into a dense tensor via feature columns.
with tf.Session() as sess:

    # Create a list of filenames and pass it to a queue. With num_epochs=1 each
    # filename is produced once before the queue raises OutOfRange.
    filename_queue = tf.train.string_input_producer(["./tfrecord.test"], num_epochs=1)

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    # The stored integers are used directly as category ids; ids outside
    # [0, num_buckets) are replaced by default_value.
    video_id = tf.feature_column.categorical_column_with_identity(
        key='video_id', num_buckets=10, default_value=0)

    # input_layer only accepts dense columns, so the categorical column is
    # wrapped in an indicator (multi-hot) column.
    video_id_indicator = tf.feature_column.indicator_column(video_id)
    columns = [video_id_indicator]

    # Shuffle the serialized records into batches of one example each.
    video_out = tf.train.shuffle_batch(
        [serialized_example], batch_size=1, capacity=30,
        num_threads=1, min_after_dequeue=3)

    # Derive the parsing spec from the columns so the parsed features match
    # what input_layer expects.
    features = tf.parse_example(
        video_out, features=tf.feature_column.make_parse_example_spec(columns))

    # Dense tensor: one row per example in the batch, one column per bucket.
    dense_video_ids = tf.feature_column.input_layer(features, columns)

    # Initialize all global and local variables. string_input_producer keeps
    # its num_epochs counter in a local variable, so the local initializer is
    # required.
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)

    # Create a coordinator and run all QueueRunner objects
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    try:
        # One batch per record written earlier.
        for _ in range(3):
            print(sess.run(dense_video_ids))
    finally:
        # Stop the queue-runner threads and wait for them, even if a batch
        # fails; the enclosing `with` closes the session afterwards.
        coord.request_stop()
        coord.join(threads)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章