tensorflow tfrecord文件生成,網絡輸入管道
標籤(空格分隔): tensorflow 源碼
在醫學圖像中,不像自然圖像那樣是規整的3通道8位數據,不同的醫學影像有不同的醫學存儲格式,以本小碩的課題來說,醫學圖像數據類型爲float32。之前爲了保證數據的原始性,一直不敢存儲爲png、bmp那樣的數據格式,而是存儲爲numpy的npz格式。
但是,對於tensorflow來說,如果採用npz存儲的話,需要一次性將數據全部讀入內存,這樣一是讀取速度特別慢;二是浪費內存。最終,本小碩還是試圖轉成tfrecord標準文件,採用tensorflow自帶的數據流圖。
轉換代碼:
import os
import sys
import numpy as np
import math
import tensorflow as tf
#import build_data
def covert_bin2tfrecord(data_dir, num_shards, save_path):
    """Convert .npy image/label arrays into sharded TFRecord files.

    Args:
        data_dir: directory containing 'data.npy' and 'label.npy'.
            'data.npy' is assumed to be (N, C, H, W) float32 — TODO confirm,
            height/width are read from axes 2 and 3 below.
        num_shards: number of TFRecord shard files to write.
        save_path: output directory for the shard files.
    """
    # Load the raw arrays once; individual slices are serialized below.
    X = np.load(os.path.join(data_dir, 'data.npy'))
    Y = np.load(os.path.join(data_dir, 'label.npy'))
    num_slices = X.shape[0]
    num_per_shard = int(math.ceil(num_slices / float(num_shards)))
    # Loop-invariant dimensions hoisted out of the per-image loop.
    # Channels derived from the array instead of a hard-coded constant.
    channels, height, width = X.shape[1], X.shape[2], X.shape[3]
    for shard_id in range(num_shards):  # range: Python 3 compatible
        output_filename = os.path.join(
            save_path,
            '%s-%05d-of-%05d.tfrecord' % (data_dir.split('/')[-1], shard_id, num_shards))
        with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
            start_idx = shard_id * num_per_shard
            end_idx = min((shard_id + 1) * num_per_shard, num_slices)
            for i in range(start_idx, end_idx):
                sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
                    i + 1, num_slices, shard_id))
                sys.stdout.flush()
                # tobytes(): tostring() is deprecated and removed in NumPy 2.0.
                image_data = tf.compat.as_bytes(X[i, ...].tobytes())
                gt_data = tf.compat.as_bytes(Y[i, ...].tobytes())
                example = tf.train.Example(features=tf.train.Features(feature={
                    'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data])),
                    'image/height': tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),
                    'image/width': tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),
                    'image/channels': tf.train.Feature(int64_list=tf.train.Int64List(value=[channels])),
                    'image/segmentation/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[gt_data])),
                    # BUG FIX: BytesList.value is a repeated field — it must be
                    # a list of bytes; passing bare b'png' raises a TypeError.
                    'image/segmentation/format': tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'png'])),
                }))
                tfrecord_writer.write(example.SerializeToString())
    sys.stdout.write('\n')
    sys.stdout.flush()
if __name__ == '__main__':
    # Paths below are placeholders — point them at the real data directories.
    covert_bin2tfrecord('/', 5, '/')   # training set
    covert_bin2tfrecord('/', 1, '/')   # test set
保存爲tfrecord文件後,爲了以防萬一,我們還是要可視化一下數據是否改變:
import tensorflow as tf
import numpy as np
from skimage import io
#from skimage import io
from glob import glob
import os
# Pin GPU enumeration order and restrict TensorFlow to GPU index 1;
# must be set before TensorFlow initializes CUDA.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
def read_tfrecord(tfrecords_filename):
    """Build graph ops that read one (image, mask) pair from TFRecord files.

    Args:
        tfrecords_filename: a single path or a list of TFRecord paths.

    Returns:
        (image, gt_mask): image is float32 reshaped to (6, 320, 320);
        gt_mask is uint8 reshaped to (320, 320).
    """
    # Accept a single path for convenience; the queue wants a list.
    if not isinstance(tfrecords_filename, list):
        tfrecords_filename = [tfrecords_filename]
    filename_queue = tf.train.string_input_producer(tfrecords_filename)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image/encoded': tf.FixedLenFeature([], tf.string),
            'image/segmentation/encoded': tf.FixedLenFeature([], tf.string),
        })
    # Decode raw bytes back to the dtypes used at serialization time.
    image = tf.decode_raw(features['image/encoded'], tf.float32)
    gt_mask = tf.decode_raw(features['image/segmentation/encoded'], tf.uint8)
    # NOTE(review): shapes are hard-coded — assumes 6x320x320 volumes with
    # 320x320 masks; confirm against the data actually written.
    image = tf.reshape(image, [6, 320, 320])
    # CONSISTENCY FIX: the inline pipeline in __main__ reshapes the mask to
    # (320, 320); previously this helper returned it as a flat 1-D tensor.
    gt_mask = tf.reshape(gt_mask, [320, 320])
    return image, gt_mask
if __name__ == '__main__':
    files = glob('/train*')
    with tf.Session() as sess:
        # Build the input pipeline: filename queue -> record reader -> parser.
        filename_queue = tf.train.string_input_producer(files)
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)
        features = tf.parse_single_example(
            serialized_example,
            features={
                'image/encoded': tf.FixedLenFeature([], tf.string),
                'image/segmentation/encoded': tf.FixedLenFeature([], tf.string)
            })
        # Decode raw bytes back to the dtypes used when serializing
        # (tf.string -> tf.float32 image, tf.uint8 mask).
        image = tf.decode_raw(features['image/encoded'], tf.float32)
        image = tf.reshape(image, (6, 320, 320))
        gt_mask = tf.decode_raw(features['image/segmentation/encoded'], tf.uint8)
        gt_mask = tf.reshape(gt_mask, (320, 320))
        # BUG FIX: capacity must exceed min_after_dequeue + batch_size,
        # otherwise shuffle_batch can never assemble a full batch — the
        # original capacity=30 with batch_size=256 stalls the queue.
        batch_size = 256
        min_after_dequeue = 20
        image_batch, gt_batch = tf.train.shuffle_batch(
            [image, gt_mask], batch_size=batch_size,
            capacity=min_after_dequeue + 3 * batch_size,
            min_after_dequeue=min_after_dequeue, num_threads=1)
        # Initialize graph variables (local ones are needed by the queues).
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())
        # Queue-runner threads feed the queues; the Coordinator shuts them down.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        ib, gb = sess.run([image_batch, gt_batch])
        print(ib.shape)
        print(gb.shape)
        # Tile the first 64 images x first 4 channels into one mosaic image.
        # BUG FIX: the canvas was allocated (81920, 1280) — room for all 256
        # batch images — but only 64 rows of tiles were ever filled, leaving
        # three quarters of the PNG black; size it to what is actually drawn.
        n_rows, n_cols = 64, 4
        data = np.zeros((n_rows * 320, n_cols * 320), dtype=np.float32)
        for i in range(n_rows):      # range: Python 3 compatible (was xrange)
            for j in range(n_cols):
                data[i * 320:(i + 1) * 320, j * 320:(j + 1) * 320] = ib[i, j, ...]
        # Visual sanity check that the round-trip preserved the data.
        io.imsave('vis_tfrecord.png', data)
        coord.request_stop()
        coord.join(threads)
例子就不方便展示了