簡介
本文介紹TensorFlow的第二種數據導入方法。
爲了保持高效,這種方法稍顯繁瑣。分爲如下幾個步驟:
- 把所有樣本寫入二進制文件(只執行一次)
- 創建Tensor
,從二進制文件讀取一個樣本
- 創建Tensor
,從二進制文件隨機讀取一個mini-batch
- 把mini-batchTensor
傳入網絡作爲輸入節點。
二進制文件
使用tf.python_io.TFRecordWriter
創建一個專門存儲tensorflow數據的writer
,擴展名爲’.tfrecord’。
該文件中依次存儲着序列化的tf.train.Example
類型的樣本。
writer = tf.python_io.TFRecordWriter('/tmp/data.tfrecord')
for i in range(0, 10):
# 創建樣本example
# ...
serialized = example.SerializeToString() # 序列化
writer.write(serialized) # 寫入文件
writer.close()
每一個example
的feature
成員變量是一個dict
,存儲一個樣本的不同部分(例如圖像像素+類標)。以下例子的樣本中包含三個鍵a,b,c
:
# 創建樣本example
a_data = 0.618 + i # float
b_data = [2016 + i, 2017+i] # int64
c_data = numpy.array([[0, 1, 2],[3, 4, 5]]) + i # bytes
c_data = c_data.astype(numpy.uint8)
c_raw = c.tostring() # 轉化成字符串
example = tf.train.Example(
features=tf.train.Features(
feature={
'a': tf.train.Feature(
float_list=tf.train.FloatList(value=[a_data]) # 方括號表示輸入爲list
),
'b': tf.train.Feature(
int64_list=tf.train.Int64List(value=b_data) # b_data本身就是列表
),
'c': tf.train.Feature(
bytes_list=tf.train.BytesList(value=[c_raw])
)
}
)
)
dict
成員的值部分接受三種類型數據:
- tf.train.FloatList
:列表每個元素爲float。例如a
。
- tf.train.Int64List
:列表每個元素爲int64。例如b
。
- tf.train.BytesList
:列表每個元素爲string。例如c
。
第三種類型尤其適合圖像樣本。注意在轉成字符串之前要設定爲uint8
類型。
讀取一個樣本
接下來,我們定義一個函數,創建“從文件中讀一個樣本”操作,返回結果Tensor
。
def read_single_sample(filename):
# 讀取樣本example的每個成員a,b,c
# ...
return a, b, c
首先創建讀文件隊列,使用tf.TFRecordReader
從文件隊列讀入一個序列化的樣本。
# 讀取樣本example的每個成員a,b,c
filename_queue = tf.train.string_input_producer([filename], num_epochs=None) # 不限定讀取數量
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
如果樣本量很大,可以分成若干文件,把文件名列表傳入tf.train.string_input_producer
。
和剛纔的writer不同,這個reader是符號化的,只有在sess中run纔會執行。
接下來解析符號化的樣本
# get feature from serialized example
features = tf.parse_single_example(
serialized_example,
features={
'a': tf.FixedLenFeature([], tf.float32), #0D, 標量
'b': tf.FixedLenFeature([2], tf.int64), # 1D,長度爲2
'c': tf.FixedLenFeature([], tf.string) # 0D, 標量
}
)
a = features['a']
b = features['b']
c_raw = features['c']
c = tf.decode_raw(c_raw, tf.uint8)
c = tf.reshape(c, [2, 3])
對於BytesList
,要重新進行解碼,把string
類型的0維Tensor
變成uint8
類型的1維Tensor
。
讀取mini-batch
使用tf.train.shuffle_batch
將前述a,b,c
隨機化,獲得mini-batchTensor
:
a_batch, b_batch, c_batch = tf.train.shuffle_batch([a, b, c], batch_size=2, capacity=200, min_after_dequeue=100, num_threads=2)
使用
創建一個session
並初始化:
# sess
sess = tf.Session()
init = tf.initialize_all_variables()
sess.run(init)
tf.train.start_queue_runners(sess=sess)
由於使用了讀文件隊列,所以要start_queue_runners
。
每一次運行,會隨機生成一個mini-batch樣本:
a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])
a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])
這樣的mini-batch可以作爲網絡的輸入節點使用。
總結
如果想進一步瞭解例子中的隊列機制,請參看這篇文章。
本文參考了以下示例:
https://github.com/mnuke/tf-slim-mnist
https://indico.io/blog/tensorflow-data-inputs-part1-placeholders-protobufs-queues/
https://github.com/tensorflow/tensorflow/tree/r0.11/tensorflow/models/image/cifar10
完整代碼如下:
import tensorflow as tf
import numpy
def write_binary():
writer = tf.python_io.TFRecordWriter('/tmp/data.tfrecord')
for i in range(0, 2):
a = 0.618 + i
b = [2016 + i, 2017+i]
c = numpy.array([[0, 1, 2],[3, 4, 5]]) + i
c = c.astype(numpy.uint8)
c_raw = c.tostring()
example = tf.train.Example(
features=tf.train.Features(
feature={
'a': tf.train.Feature(
float_list=tf.train.FloatList(value=[a])
),
'b': tf.train.Feature(
int64_list=tf.train.Int64List(value=b)
),
'c': tf.train.Feature(
bytes_list=tf.train.BytesList(value=[c_raw])
)
}
)
)
serialized = example.SerializeToString()
writer.write(serialized)
writer.close()
def read_single_sample(filename):
# output file name string to a queue
filename_queue = tf.train.string_input_producer([filename], num_epochs=None)
# create a reader from file queue
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
# get feature from serialized example
features = tf.parse_single_example(
serialized_example,
features={
'a': tf.FixedLenFeature([], tf.float32),
'b': tf.FixedLenFeature([2], tf.int64),
'c': tf.FixedLenFeature([], tf.string)
}
)
a = features['a']
b = features['b']
c_raw = features['c']
c = tf.decode_raw(c_raw, tf.uint8)
c = tf.reshape(c, [2, 3])
return a, b, c
#-----main function-----
if 1:
write_binary()
else:
# create tensor
a, b, c = read_single_sample('/tmp/data.tfrecord')
a_batch, b_batch, c_batch = tf.train.shuffle_batch([a, b, c], batch_size=3, capacity=200, min_after_dequeue=100, num_threads=2)
queues = tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)
# sess
sess = tf.Session()
init = tf.initialize_all_variables()
sess.run(init)
tf.train.start_queue_runners(sess=sess)
a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])
print(a_val, b_val, c_val)
a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])
print(a_val, b_val, c_val)