tfrecord文件應用流程

對於輸入數據的處理,大體上流程都差不多,可以歸結如下

  1. 將數據轉爲 TFRecord 格式的多個文件
  2. 用 tf.train.match_filenames_once() 創建文件列表
  3. 用 tf.train.string_input_producer() 創建輸入文件隊列,可以將輸入文件順序隨機打亂
  4. 用 tf.TFRecordReader() 讀取文件中的數據
  5. 用 tf.parse_single_example() 解析數據
  6. 對數據進行解碼及預處理
  7. 用 tf.train.shuffle_batch() 將數據組合成 batch
  8. 將 batch 用於訓練

輸入數據處理框架


框架主要是三方面的內容

  • TFRecord 輸入數據格式
  • 圖像數據處理
  • 多線程輸入數據處理

以下代碼只是描繪了一個輸入數據處理的框架,需要根據實際使用環境進行修改(代碼實現自《TensorFlow:實戰Google深度學習框架》

import tensorflow as tf

# 創建文件列表
files = tf.train.match_filenames_once('data/data.tfrecords-*')

# 創建輸入文件隊列
filename_queue = tf.train.string_input_producer(files, shuffle=Flase)

# 解析數據。假設image是圖像數據,label是標籤,height、width、channels給出了圖片的維度
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
    serialized_example,
    features={
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
        'height': tf.FixedLenFeature([], tf.int64),
        'width': tf.FixedLenFeature([], tf.int64),
        'channels': tf.FixedLenFeature([], tf.int64)
    })
image, label = features['image'], features['label']
height, width = tf.cast(features['height'], tf.int32), tf.cast(features['width'], tf.int32)
channels = tf.cast(features['channels'], tf.int32)

# 從原始圖像中解析出像素矩陣,並還原圖像
decoded_image = tf.decode_raw(image, tf.uint8)
decoded_image.set_shape([height, width, channels])

# 定義神經網絡輸入層圖片的大小
image_size = 299

# preprocess_for_train函數是對圖片進行預處理的函數
distorted_image = preprocess_for_train(decoded_image, image_size, image_size,
                                       None)

# 組合成batch
min_after_dequeue = 10000
batch_size = 100
capacity = min_after_dequeue + 3 * batch_size
image_batch, label_batch = tf.train.shuffle_batch(
    [distorted_image, label],
    batch_size=batch_size,
    capacity=capacity,
    min_after_dequeue=min_after_dequeue)

# 定義神經網絡的結構及優化過程
logit = inference(image_batch)
loss = calc_loss(logit, label_batch)
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

with tf.Session() as sess:
    sess.run(
        [tf.global_variables_initializer(),
         tf.local_variables_initializer()])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    # 神經網絡訓練過程
    for i in range(TRAINING_ROUNDS):
        sess.run(train_step)

    coord.request_stop()
    coord.join()
發佈了11 篇原創文章 · 獲贊 38 · 訪問量 9萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章