對於數據量很大的數據集, 一次性直接讀入內存可能會放不下. 建議的做法是把全部數據轉換成tfrecord格式, 方便神經網絡按需讀取數據; 而且tensorflow對從tfrecord中讀取數據專門做過優化, 能加快讀取速度.
參考資料: TensorFlow官方tfrecord讀寫教程 (https://www.tensorflow.org/tutorials/load_data/tfrecord)
1. 生成tfrecord
方法1: 直接以二進制bytes讀取圖片文件, 然後放進tfrecord中. 但這樣得到的是編碼後(如PNG/JPEG)的bytes, 沒法直接修改像素值; 比如有時候label需要做映射(map), 這時候就要用方法2.
import tensorflow as tf
# Helper: turn one bytes value into a single-element bytes_list feature.
def _bytes_list_feature(value):
    """Wrap a single ``bytes`` value in a ``tf.train.Feature``.

    Args:
      value: one bytes object; it becomes the only element of the feature's
        ``bytes_list``.

    Returns:
      A ``tf.train.Feature`` whose ``bytes_list`` holds ``[value]``.
    """
    single_item = [value]
    wrapped = tf.train.BytesList(value=single_item)
    return tf.train.Feature(bytes_list=wrapped)
# Helper: pack one (image, label) pair into a tf.train.Example.
def image_seg_to_tfexample(image_data, seg_data):
    """Build a ``tf.train.Example`` holding one image/label pair.

    Args:
      image_data: image as bytes (raw file contents or raw pixel bytes),
        stored under the 'image' key.
      seg_data: segmentation label as bytes, stored under the 'label' key.

    Returns:
      A ``tf.train.Example`` containing the two bytes features.
    """
    feature_map = {
        'image': _bytes_list_feature(image_data),
        'label': _bytes_list_feature(seg_data),
    }
    feature_set = tf.train.Features(feature=feature_map)
    return tf.train.Example(features=feature_set)
# Method 1: read the image and label files as raw (still encoded) bytes.
# Reading with the built-in open(image_filename, 'rb') gives the same bytes.
# The `with` blocks close the file handles, which the original one-liners leaked.
with tf.gfile.GFile(image_filename, 'rb') as image_file:
    image_data = image_file.read()  # type(image_data) is bytes
with tf.gfile.GFile(seg_filename, 'rb') as seg_file:
    seg_data = seg_file.read()
# NOTE: tf.read_file(image_filename) is NOT a drop-in replacement here: it
# returns a string Tensor rather than Python bytes, so it would first have to
# be evaluated before it could be stored in a tf.train.Example.
with tf.python_io.TFRecordWriter(output_filename) as writer:
    example = image_seg_to_tfexample(image_data, seg_data)
    # Serialize the Example and append it to the TFRecord file.
    writer.write(example.SerializeToString())
方法2: 不直接把圖片讀取成編碼後的bytes, 而是解碼成ndarray, 這樣可以先對ndarray進行修改, 再寫入tfrecord中. 注意: 這樣寫入的是未壓縮的原始像素bytes, 不包含圖片尺寸信息, 所以讀取時必須事先知道圖片形狀(見下文decode_raw之後的reshape).
import os
import pickle

import numpy as np
import tensorflow as tf
from PIL import Image
# Load the previously saved dictionary used below to remap label values.
# NOTE(review): pickle.load can execute arbitrary code - only load trusted files.
with open('/home/steven/deeplab_v3+_project/deeplab_v3+_tensorflow_from_rishizek/map_dictionary.pickle', 'rb') as f:
    map_dict = pickle.load(f)

# Read the image into an ndarray and cast it to np.uint8 (pixel values lie in 0-255).
# Using Image.open as a context manager closes the underlying file handle;
# np.array() forces the pixel data to load before the file is closed.
with Image.open(image_filename) as pil_image:
    image_data = np.array(pil_image).astype(np.uint8)
# Convert the ndarray to raw bytes for the TFRecord. tobytes() is the
# non-deprecated spelling of tostring() and returns exactly the same bytes.
# Only the pixel bytes are stored, not the shape, so the reader must know it.
image_data = image_data.tobytes()

# Read the label into an ndarray. Do not cast to np.uint8 yet, because the
# mapped values may have a different dtype.
with Image.open(seg_filename) as pil_label:
    seg_data = np.array(pil_label)
# Remap every label value through map_dict. Indexing (instead of .get) raises
# a KeyError naming any label value missing from map_dict, instead of yielding
# None and failing later with a less informative error.
seg_data_mapped = np.vectorize(map_dict.__getitem__)(seg_data)
# Cast to np.uint8 before serializing so the bytes match the tf.uint8 that
# decode_raw uses when reading. NOTE(review): values outside 0-255 wrap around here.
seg_data = seg_data_mapped.astype(np.uint8).tobytes()

with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
    # image_seg_to_tfexample() is the helper defined for method 1.
    example = image_seg_to_tfexample(image_data, seg_data)
    tfrecord_writer.write(example.SerializeToString())
2. 讀取tfrecord: tf.data.TFRecordDataset()
# Return the list of all TFRecord shard files to read for one split.
def get_filenames(is_training, data_dir, num_shards=4):
    """Return the TFRecord shard paths for the train or validation split.

    Shard files follow the naming scheme
    ``nonzeros_<split>-<index>-of-<num_shards>.tfrecord`` with both numbers
    zero-padded to five digits, e.g. ``nonzeros_train-00000-of-00004.tfrecord``.

    Args:
      is_training: if truthy, return the 'train' shards; otherwise the 'valid' shards.
      data_dir: directory containing the shard files.
      num_shards: number of shards in the split. Defaults to 4, the value that
        was previously hard-coded.

    Returns:
      A list of ``num_shards`` paths ordered by shard index.
    """
    split = 'train' if is_training else 'valid'
    return [
        os.path.join(
            data_dir,
            'nonzeros_{}-{:05d}-of-{:05d}.tfrecord'.format(split, index, num_shards))
        for index in range(num_shards)
    ]
# Create one dataset that reads records from every shard of the chosen split.
dataset = tf.data.TFRecordDataset(filenames=get_filenames(is_training, data_dir))
# Parser for records written by method 1: decode the encoded image bytes back into tensors.
def parse_record(raw_record):
    """Parse one serialized Example holding encoded image and label bytes.

    Args:
      raw_record: string Tensor with one serialized ``tf.train.Example``.

    Returns:
      (image, label): the image cast with ``tf.to_float`` and the label cast
      with ``tf.to_int32``, both with static shape ``[None, None, 1]``.
    """
    # The parsing spec must mirror the features used when writing.
    spec = {
        'image': tf.FixedLenFeature((), tf.string),
        'label': tf.FixedLenFeature((), tf.string),
    }
    features = tf.parse_single_example(raw_record, spec)

    def _decode_single_channel(key, cast_fn):
        # Decode the stored bytes as a 1-channel image, convert to uint8, then cast.
        encoded = tf.reshape(features[key], shape=[])
        decoded = tf.image.decode_image(encoded, 1)
        decoded = cast_fn(tf.image.convert_image_dtype(decoded, dtype=tf.uint8))
        decoded.set_shape([None, None, 1])
        return decoded

    decoded_image = _decode_single_channel('image', tf.to_float)
    decoded_label = _decode_single_channel('label', tf.to_int32)
    return decoded_image, decoded_label
# Parser for records written by method 2: reinterpret the raw ndarray bytes with decode_raw().
def parse_record(raw_record, image_shape=(256, 256, 1), label_shape=(256, 256, 1)):
    """Parse one serialized Example holding raw uint8 pixel bytes.

    Method 2 stores the output of ``ndarray.tostring()``/``tobytes()``, which
    carries no shape information, so the pixel shapes are supplied here.

    Args:
      raw_record: string Tensor with one serialized ``tf.train.Example``.
      image_shape: shape to reshape the decoded image bytes to. Defaults to
        (256, 256, 1), the value previously hard-coded; pass e.g. (H, W, 3)
        for RGB images.
      label_shape: shape to reshape the decoded label bytes to. Defaults to
        (256, 256, 1).

    Returns:
      (image, label): image cast with ``tf.to_float`` and label cast with
      ``tf.to_int32``, reshaped to ``image_shape`` / ``label_shape``.
    """
    keys_to_features = {
        'image': tf.FixedLenFeature((), tf.string),
        'label': tf.FixedLenFeature((), tf.string),
    }
    parsed = tf.parse_single_example(raw_record, keys_to_features)
    # decode_raw reinterprets the bytes as a flat vector; tf.uint8 must match
    # the np.uint8 dtype used when the bytes were written.
    image = tf.decode_raw(parsed['image'], tf.uint8)
    image = tf.reshape(tf.to_float(image), list(image_shape))
    label = tf.decode_raw(parsed['label'], tf.uint8)
    label = tf.reshape(tf.to_int32(label), list(label_shape))
    return image, label
# Turn every serialized record into an (image, label) tensor pair.
dataset = dataset.map(parse_record)
# Run preprocess_image (defined elsewhere) on every pair, passing is_training through.
dataset = dataset.map(lambda img, lbl: preprocess_image(img, lbl, is_training))
# Shuffle with a buffer of buffer_size elements, repeat num_epochs times,
# then group consecutive elements into batches of batch_size.
dataset = dataset.shuffle(buffer_size)
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size)
# Each sess.run([images, labels]) fetches the next batch of images and labels.
iterator = dataset.make_one_shot_iterator()
images, labels = iterator.get_next()
3. 通過tf.data.Dataset直接讀取圖片文件 (不經過tfrecord)
def eval_input_fn(image_filenames, label_filenames=None, batch_size=1):
    """Build a batched input pipeline that reads image files directly (no TFRecord).

    Args:
      image_filenames: list of image file paths.
      label_filenames: optional list of label file paths, aligned element-wise
        with ``image_filenames``. When None, only images are produced.
      batch_size: number of elements per batch.

    Returns:
      (images, labels): tensors from the iterator's ``get_next()``. ``labels``
      is None when ``label_filenames`` is None.
    """

    def _decode_image(image_filename):
        # Same idea as methods 1/2, but the bytes from tf.read_file() are decoded
        # directly instead of being stored in a TFRecord first.
        # channels=3 forces a 3-channel (RGB) result, so set_shape below is accurate.
        image_string = tf.read_file(image_filename)
        image = tf.image.decode_image(image_string, channels=3)
        image = tf.to_float(tf.image.convert_image_dtype(image, dtype=tf.uint8))
        image.set_shape([None, None, 3])
        return image

    def _decode_label(label_filename):
        # channels=1 forces a single-channel result, matching set_shape below.
        # NOTE(review): assumes grayscale label files; palette/RGB labels would be
        # converted to grayscale intensities rather than class ids.
        label_string = tf.read_file(label_filename)
        label = tf.image.decode_image(label_string, channels=1)
        label = tf.to_int32(tf.image.convert_image_dtype(label, dtype=tf.uint8))
        label.set_shape([None, None, 1])
        return label

    if label_filenames is None:
        # Dataset of single filenames -> dataset of images.
        dataset = tf.data.Dataset.from_tensor_slices(image_filenames)
        dataset = dataset.map(_decode_image)
    else:
        # Dataset of (image_filename, label_filename) pairs -> (image, label) pairs.
        dataset = tf.data.Dataset.from_tensor_slices((image_filenames, label_filenames))
        dataset = dataset.map(
            lambda image_filename, label_filename: (
                _decode_image(image_filename), _decode_label(label_filename)))

    # Batch first and prefetch last, so whole batches are prepared ahead of use.
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(1)
    iterator = dataset.make_one_shot_iterator()
    if label_filenames is None:
        return iterator.get_next(), None
    images, labels = iterator.get_next()
    return images, labels