tensorflow之數據讀取 -- 用tf.data通過tfrecord讀取數據或者直接讀取數據

對於數據量很大的數據集, 直接讀入內存可能會放不下, 建議的做法是把全部數據轉換成tfrecord的格式, 方便神經網絡讀取數據, 並且從tfrecord中讀取數據的話tensorflow專門做過優化, 能加快讀取速度.

參考資料: 官方tfrecord讀寫教程

1. 生成tfrecord

方法1: 直接以二進制bytes讀取圖片, 然後放進tfrecord中, 但是這樣對bytes沒法做修改, 比如有時候label需要進行map, 這時候就要用方法2.

import tensorflow as tf

# 把一個byte數據轉換成一個bytes_list
# Wrap a single bytes value into a tf.train.Feature holding a BytesList.
def _bytes_list_feature(value):
  bytes_list = tf.train.BytesList(value=[value])
  return tf.train.Feature(bytes_list=bytes_list)

# 把一對features和label轉換成一個tfexample
# Build a tf.train.Example holding one serialized image/label pair.
def image_seg_to_tfexample(image_data, seg_data):
  feature_map = {
      'image': _bytes_list_feature(image_data),
      'label': _bytes_list_feature(seg_data),
  }
  return tf.train.Example(features=tf.train.Features(feature=feature_map))

# Read image and label files as raw binary bytes; image_data = open(image_filename, 'rb').read() has the same effect
image_data = tf.gfile.GFile(image_filename, 'rb').read() # type(image_data) is bytes
seg_data = tf.gfile.GFile(seg_filename, 'rb').read() 
# image_data = tf.read_file(image_filename) also works; type(image_data) is bytes as well
# NOTE(review): tf.gfile / tf.python_io are TF 1.x APIs (tf.io.gfile / tf.io.TFRecordWriter in TF 2.x)

with tf.python_io.TFRecordWriter(output_filename) as writer:
  example = image_seg_to_tfexample(image_data,seg_data)
  # Serialize the tf.train.Example and append it to the TFRecord file
  writer.write(example.SerializeToString())

方法2: 不直接把圖片讀取成bytes, 而是轉換成ndarray, 這樣可以對ndarray進行修改, 再寫入tfrecord中.

from PIL import Image
import numpy as np
import tensorflow as tf

# Load the previously saved dictionary, used below to remap label values
# NOTE(review): `pickle` is used below but never imported in this snippet
# (only PIL, numpy and tensorflow are imported above) -- import it here.
import pickle

with open('/home/steven/deeplab_v3+_project/deeplab_v3+_tensorflow_from_rishizek/map_dictionary.pickle', 'rb') as f:
  map_dict = pickle.load(f)

# Read the image into an ndarray; cast to np.uint8 because pixel values lie in 0-255
image_data = np.array(Image.open(image_filename)).astype(np.uint8)
# Serialize the ndarray to raw bytes for the TFRecord.
# np.ndarray.tostring() was deprecated and removed in NumPy 2.0; tobytes() is
# the byte-identical replacement (available since NumPy 1.9).
image_data = image_data.tobytes()

# Read the label into an ndarray; don't cast to np.uint8 yet, since the map may change the dtype
seg_data = np.array(Image.open(seg_filename))
# Apply the mapping dictionary element-wise to the label array
seg_data_mapped = np.vectorize(map_dict.get)(seg_data)
# Cast back to np.uint8 first, then serialize to raw bytes for the TFRecord
seg_data = seg_data_mapped.astype(np.uint8).tobytes()
  
# Serialize the example and append it to the TFRecord file.
with tf.python_io.TFRecordWriter(output_filename) as writer:
  # image_seg_to_tfexample() is defined in method 1 above
  writer.write(image_seg_to_tfexample(image_data, seg_data).SerializeToString())

2. 讀取tfrecord: tf.data.TFRecordDataset()

# Return the list of TFRecord shard paths for the requested split.
def get_filenames(is_training, data_dir):
    split = 'nonzeros_train' if is_training else 'nonzeros_valid'
    # Shards follow the <split>-0000i-of-00004.tfrecord naming scheme.
    return [os.path.join(data_dir, '%s-%05d-of-00004.tfrecord' % (split, shard))
            for shard in range(4)]

# Build a dataset of serialized examples from all the TFRecord shard files
dataset = tf.data.TFRecordDataset(get_filenames(is_training,data_dir))

# Parse one serialized example back into (image, label) by decoding the stored
# encoded-image bytes with tf.image.decode_image -- matches "method 1" above.
# NOTE(review): the original mixed tab and space indentation inside this
# function, which raises TabError under Python 3; normalized to 4-space indents.
def parse_record(raw_record):
    # Features must be parsed with the same keys/types they were written with
    keys_to_features = {
      'image': tf.FixedLenFeature((), tf.string),
      'label': tf.FixedLenFeature((), tf.string),
    }
    # Deserialize the raw record according to keys_to_features
    parsed = tf.parse_single_example(raw_record, keys_to_features)

    # Decode the image bytes into a single-channel uint8 tensor, then cast to float
    image = tf.image.decode_image(tf.reshape(parsed['image'], shape=[]), 1)
    image = tf.to_float(tf.image.convert_image_dtype(image, dtype=tf.uint8))
    image.set_shape([None, None, 1])
    # Decode the label the same way, but cast to int32 for loss computation
    label = tf.image.decode_image(tf.reshape(parsed['label'], shape=[]), 1)
    label = tf.to_int32(tf.image.convert_image_dtype(label, dtype=tf.uint8))
    label.set_shape([None, None, 1])

    return image, label

# Parse one serialized example written via "method 2": the payload is a raw
# uint8 byte string, so decode it with tf.decode_raw and reshape to 256x256x1.
def parse_record(raw_record):
    feature_spec = {
      'image': tf.FixedLenFeature((), tf.string),
      'label': tf.FixedLenFeature((), tf.string),
    }
    parsed = tf.parse_single_example(raw_record, feature_spec)

    decoded_image = tf.decode_raw(parsed['image'], tf.uint8)
    image = tf.reshape(tf.to_float(decoded_image), [256, 256, 1])

    decoded_label = tf.decode_raw(parsed['label'], tf.uint8)
    label = tf.reshape(tf.to_int32(decoded_label), [256, 256, 1])

    return image, label

# Apply parse_record to every serialized example, yielding a dataset of (image, label)
dataset = dataset.map(parse_record)
# Apply preprocess_image() (definition omitted here) to each (image, label) pair
dataset = dataset.map(lambda image, label: preprocess_image(image, label, is_training))
# The dataset can additionally be shuffled, repeated, batched, etc.
dataset = dataset.shuffle(buffer_size).repeat(num_epochs).batch(batch_size)
# Each sess.run([images, labels]) yields one batch of images and labels
iterator = dataset.make_one_shot_iterator()
images, labels = iterator.get_next()

3. 通過tf.data.Dataset直接讀取數據

def eval_input_fn(image_filenames, label_filenames=None, batch_size=1):
  """Build a one-shot-iterator input pipeline directly from image files.

  Args:
    image_filenames: list of image file paths.
    label_filenames: optional list of label file paths, parallel to
      image_filenames; when None, only images are produced.
    batch_size: number of examples per batch.

  Returns:
    (images, labels) tensors from the iterator; labels is None when
    label_filenames is None.
  """
  # Reads an image from a file, decodes it into a dense tensor
  def _parse_function(filename, is_label):
    # `filename` is a single image path when is_label is False, or an
    # (image_path, label_path) pair when is_label is True
    if not is_label:
      image_filename, label_filename = filename, None
    else:
      image_filename, label_filename = filename
    
    # Similar to the read-and-parse steps in sections 1 and 2 above, except the
    # bytes returned by tf.read_file() are decoded straight into a tensor with
    # tf.image.decode_image() instead of going through a TFRecord
    image_string = tf.read_file(image_filename)
    image = tf.image.decode_image(image_string)
    image = tf.to_float(tf.image.convert_image_dtype(image, dtype=tf.uint8))
    # assumes 3-channel RGB input images -- TODO confirm against the data
    image.set_shape([None, None, 3])

    if not is_label:
      return image
    else:
      # Read and decode the label image (single channel, int32 for losses)
      label_string = tf.read_file(label_filename)
      label = tf.image.decode_image(label_string)
      label = tf.to_int32(tf.image.convert_image_dtype(label, dtype=tf.uint8))
      label.set_shape([None, None, 1])

      return image, label
      
  if label_filenames is None:
    input_filenames = image_filenames
  else:
    input_filenames = (image_filenames, label_filenames)
  
  # input_filenames is either a list of filenames or a tuple of two lists;
  # tf.data.Dataset.from_tensor_slices() turns it into a dataset of filenames
  dataset = tf.data.Dataset.from_tensor_slices(input_filenames)
  # Map _parse_function over the filename dataset to decode the actual tensors
  if label_filenames is None:
    dataset = dataset.map(lambda x: _parse_function(x, False))
  else:
    # With a tuple-of-lists dataset, map passes two arguments to the lambda
    dataset = dataset.map(lambda x, y: _parse_function((x, y), True))
  dataset = dataset.prefetch(batch_size)
  dataset = dataset.batch(batch_size)
  iterator = dataset.make_one_shot_iterator()

  if label_filenames is None:
    images = iterator.get_next()
    labels = None
  else:
    images, labels = iterator.get_next()

  return images, labels

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章