對於輸入數據的處理,大體上流程都差不多,可以歸結如下
- 將數據轉爲 TFRecord 格式的多個文件
- 用 tf.train.match_filenames_once() 創建文件列表
- 用 tf.train.string_input_producer() 創建輸入文件隊列,可以將輸入文件順序隨機打亂
- 用 tf.TFRecordReader() 讀取文件中的數據
- 用 tf.parse_single_example() 解析數據
- 對數據進行解碼及預處理
- 用 tf.train.shuffle_batch() 將數據組合成 batch
- 將 batch 用於訓練
輸入數據處理框架
框架主要是三方面的內容
- TFRecord 輸入數據格式
- 圖像數據處理
- 多線程輸入數據處理
以下代碼只是描繪了一個輸入數據處理的框架,需要根據實際使用環境進行修改(代碼實現自《TensorFlow:實戰Google深度學習框架》)
import tensorflow as tf
# 創建文件列表
files = tf.train.match_filenames_once('data/data.tfrecords-*')
# 創建輸入文件隊列
filename_queue = tf.train.string_input_producer(files, shuffle=Flase)
# 解析數據。假設image是圖像數據,label是標籤,height、width、channels給出了圖片的維度
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
features={
'image': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64),
'height': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'channels': tf.FixedLenFeature([], tf.int64)
})
image, label = features['image'], features['label']
height, width = tf.cast(features['height'], tf.int32), tf.cast(features['width'], tf.int32)
channels = tf.cast(features['channels'], tf.int32)
# 從原始圖像中解析出像素矩陣,並還原圖像
decoded_image = tf.decode_raw(image, tf.uint8)
decoded_image.set_shape([height, width, channels])
# 定義神經網絡輸入層圖片的大小
image_size = 299
# preprocess_for_train函數是對圖片進行預處理的函數
distorted_image = preprocess_for_train(decoded_image, image_size, image_size,
None)
# 組合成batch
min_after_dequeue = 10000
batch_size = 100
capacity = min_after_dequeue + 3 * batch_size
image_batch, label_batch = tf.train.shuffle_batch(
[distorted_image, label],
batch_size=batch_size,
capacity=capacity,
min_after_dequeue=min_after_dequeue)
# 定義神經網絡的結構及優化過程
logit = inference(image_batch)
loss = calc_loss(logit, label_batch)
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
with tf.Session() as sess:
sess.run(
[tf.global_variables_initializer(),
tf.local_variables_initializer()])
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
# 神經網絡訓練過程
for i in range(TRAINING_ROUNDS):
sess.run(train_step)
coord.request_stop()
coord.join()