tensorflow筆記 tfrecord創建及讀取

之前很少仔細看tf的一些基礎api,只要能跑通就過了,最近打算花時間把部分基礎api整理一下,方便以後使用。

簡介

tfrecord是tensorflow訓練模型時比較常用的處理大量數據的格式。簡單來說,一種二進制數據儲存格式,比一次性讀取csv或jpg數據要更快,且佔用更小的內存。

生成tfrecord文件

考慮一個簡單的分類問題數據集,feature是一個1x5的向量,label取值爲0或1

import numpy as np
import tensorflow as tf

#構建一個簡單的分類問題數據集,feature爲一個1x5的隨機向量,label取值爲0或1

#生成10個隨機樣本,其中一半樣本label爲0,另一半爲1
n = 10
size = (n, 5)

x_data = np.random.randint(0, 10, size=size)
y1_data = np.ones((n//2, 1), int)
y2_data = np.zeros((n//2, 1), int)
y_data = np.vstack((y1_data, y2_data))
np.random.shuffle(y_data)
xy_data = np.hstack((x_data,y_data))
#print(xy_data)
'''
[[2 0 0 5 8 1]
 [8 3 7 5 1 1]
 [3 5 7 8 7 1]
 [5 2 7 9 9 0]
 [0 1 0 3 0 0]
 [0 3 4 2 5 0]
 [4 8 8 3 8 1]
 [3 5 2 7 7 0]
 [0 4 7 7 3 1]
 [5 0 2 4 9 0]]
'''

#儲存爲tfrecord格式,文件名以.record爲後綴
tfrecord_path = 'data.record'
writer = tf.python_io.TFRecordWriter(tfrecord_path)
for i in range(n):
	#讀入的數據需要先轉化爲list
    sample = x_data[i] 
    label = int(y_data[i])
    example = tf.train.Example(features=tf.train.Features(feature={
        'sample':
            tf.train.Feature(int64_list=tf.train.Int64List(value=sample)),
        'label':
            tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
    }))
    writer.write(example.SerializeToString())
    #print(example)
    #print(example.SerializeToString())
writer.close()

'''
example格式:
features {
  feature {
    key: "label"
    value {
      int64_list {
        value: 1
      }
    }
  }
  feature {
    key: "sample"
    value {
      int64_list {
        value: 2
        value: 0
        value: 0
        value: 5
        value: 8
      }
    }
  }
}
'''

'''
最後存入的SerializeToString()格式:
b'\n%\n\x13\n\x06sample\x12\t\x1a\x07\n\x05\x02\x00\x00\x05\x08\n\x0e\n\x05label\x12\x05\x1a\x03\n\x01\x01'
'''

用parse_single_example讀取tfrecord文件

#讀取tfrecord文件
input_filename = 'data.record'

#建立文件名隊列
filename_queue = tf.train.string_input_producer([input_filename], num_epochs=3)

# 建立閱讀器
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)

# 根據數據格式建立相對應的讀取features
features = {
    'sample': tf.FixedLenFeature([5], tf.int64),# 如果不是標量,一定要在這裏說明數組的長度
    'label': tf.FixedLenFeature([], tf.int64)
}
# 解析單個EXAMPLE
Features = tf.parse_single_example(serialized_example, features)
sample = tf.cast(Features['sample'], tf.float32)
label = tf.cast(Features['label'], tf.float32)

sample_single, label_single = tf.train.batch([sample, label],
                                                batch_size=2, #兩個數據爲一個batch
                                                capacity=200,#隊列最大容量
                                                num_threads=1,
                                                enqueue_many = False #爲Ture時會按batch_size長度截斷
                                            	)

'''
sample_single, label_single = tf.train.shuffle_batch([sample, label],
                                                 batch_size=2,
                                                 capacity=200,
                                                 min_after_dequeue=100,#多少數據之後開始shuffle
                                                 num_threads=1,
                                                 enqueue_many=if_enq_many)
'''

print(sample_single, label_single)

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    #如果tf.train.string_input_producer([tfrecord_path], num_epochs=3)中num_epochs不爲空,必須要初始化local變量
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()  # 管理線程
    threads = tf.train.start_queue_runners(coord=coord)  # 文件名開始進入文件名隊列和內存
    for i in range(1):
        # Queue + tf.parse_single_example()讀取tfrecord文件
        X1, Y1 = sess.run([sample_single, label_single])
        print('X1: ', X1, 'Y1: ', Y1) # 這裏就可以得到tensor具體的數值
    coord.request_stop()
    coord.join(threads)

'''
batch():
Tensor("batch:0", shape=(2, 5), dtype=float32) Tensor("batch:1", shape=(2,), dtype=float32)
X1:  [[2. 0. 0. 5. 8.]
 [8. 3. 7. 5. 1.]] Y1:  [1. 1.]

shuffle_batch():
Tensor("shuffle_batch:0", shape=(2, 5), dtype=float32) Tensor("shuffle_batch:1", shape=(2,), dtype=float32)
X1:  [[2. 0. 0. 5. 8.]
 [0. 3. 4. 2. 5.]] Y1:  [1. 0.]
 '''

用parse_example讀取tfrecord文件

'''
用tf.parse_example()批量讀取數據,據說比tf.parse_single_exaple()讀取數據的速度快(沒有驗證)
args:
      filename_queue: 文件名隊列
      shuffle_batch: 是否批量讀取數據
      if_enq_many: batch時enqueue_many參數的設定,這裏主要用於評估該參數的作用
'''
# 第一步: 建立文件名隊列
input_filename = 'data.record'

filename_queue = tf.train.string_input_producer([input_filename])

reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)


batch = tf.train.shuffle_batch([serialized_example],
                        batch_size=3,
                        capacity=10000,
                        min_after_dequeue=1000,
                        num_threads=1,
                        enqueue_many=False)


'''
batch = tf.train.batch([serialized_example],
                        batch_size=3,
                        capacity=10000,
                        num_threads=1,
                        enqueue_many=False)
'''

features = {
    'sample': tf.FixedLenFeature([5], tf.int64),  # 如果不是標量,一定要在這裏說明數組的長度
    'label': tf.FixedLenFeature([], tf.int64)
}

Features = tf.parse_example(batch, features)

samples_parse = tf.cast(Features['sample'], tf.float32)
labels_parse = tf.cast(Features['label'], tf.float32)

print(samples_parse, labels_parse)

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)  # 初始化
    coord = tf.train.Coordinator()  # 管理線程
    threads = tf.train.start_queue_runners(coord=coord)  # 文件名開始進入文件名隊列和內存
    for i in range(1):
        X2, Y2 = sess.run([samples_parse, labels_parse])
        print('X2: ', X2, 'Y2: ', Y2)

    coord.request_stop()
    coord.join(threads)

'''
shuffle_batch():
Tensor("Cast_6:0", shape=(3, 5), dtype=float32) Tensor("Cast_7:0", shape=(3,), dtype=float32)
X2:  [[2. 7. 6. 8. 1.]
 [6. 5. 6. 9. 6.]
 [8. 4. 8. 2. 2.]] Y2:  [0. 0. 0.]

batch():
Tensor("Cast_8:0", shape=(3, 5), dtype=float32) Tensor("Cast_9:0", shape=(3,), dtype=float32)
X2:  [[2. 0. 0. 5. 8.]
 [8. 3. 7. 5. 1.]
 [3. 5. 7. 8. 7.]] Y2:  [1. 1. 1.]
'''

注意如果將sparse_single_example和sparse_example放在同一jupyter notebook文件中運行,會報錯:FIFOQueue '_27_batch_1/fifo_queue' is closed and has insufficient elements (requested 3, current size 0)
原因是filename_queue是一個local variable,將init改爲:init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())即可

參考

TF官方文檔
簡書:Tensorflow 數據讀取

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章