tensorflow tfrecord文件生成,網絡輸入管道

tensorflow tfrecord文件生成,網絡輸入管道

標籤(空格分隔): tensorflow 源碼


醫學圖像不像自然圖像那樣是規整的3通道8位數據,不同的醫學影像有不同的存儲格式。以本小碩的課題來說,醫學圖像的數據類型爲float32。之前爲了保證數據的原始性(png、bmp等8位格式無法無損保存float32數據),一直不敢存儲爲png、bmp那樣的圖像格式,而是存儲爲numpy的npy/npz格式。

但是,對於tensorflow來說,如果採用npz存儲的話,需要一次性將數據全部讀入內存,這樣一是讀取速度特別慢,二是浪費內存。最終,本小碩還是決定轉成tensorflow的標準文件格式tfrecord,使用tensorflow自帶的輸入管道(數據流圖)來讀取數據。

轉換代碼:

import os
import sys
import numpy as np
import math
import tensorflow as tf
#import build_data

def covert_bin2tfrecord(data_dir,num_shards,save_path):

    #讀取原始數據
    X=np.load(os.path.join(data_dir,'data.npy'))
    Y=np.load(os.path.join(data_dir,'label.npy'))

    num_slices=X.shape[0]
    num_per_shard=int(math.ceil(num_slices/float(num_shards)))
    for shard_id in xrange(num_shards):
        output_filename=os.path.join(save_path,'%s-%05d-of-%05d.tfrecord' %(data_dir.split('/')[-1],shard_id,num_shards))
        with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
            start_idx=shard_id * num_per_shard
            end_idx = min((shard_id+1)*num_per_shard,num_slices)
            for i in xrange(start_idx,end_idx):
                sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
                    i + 1, num_slices, shard_id))
                sys.stdout.flush()
                height,width = X.shape[2],X.shape[3]
                image_data = tf.compat.as_bytes(X[i,...].tostring())
                gt_data = tf.compat.as_bytes(Y[i,...].tostring())
                example = tf.train.Example(features=tf.train.Features(feature={
                    'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data])),
                    'image/height': tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),
                    'image/width': tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),
                    'image/channels': tf.train.Feature(int64_list=tf.train.Int64List(value=[4])),
                    'image/segmentation/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[gt_data])),
                    'image/segmentation/format': tf.train.Feature(bytes_list=tf.train.BytesList(value=b'png')),
                }))
                tfrecord_writer.write(example.SerializeToString())
            sys.stdout.write('\n')
            sys.stdout.flush()

if __name__ == '__main__':
    # (data_dir, num_shards, save_path) for each split. The '/' paths look
    # like placeholders; replace them with the real split directories.
    splits = (
        ('/', 5, '/'),  # training set
        ('/', 1, '/'),  # test set
    )
    for split_dir, split_shards, split_out in splits:
        covert_bin2tfrecord(split_dir, split_shards, split_out)

保存爲tfrecord文件後,爲了以防萬一,還是要把數據讀回來可視化一下,檢查數據是否發生了改變:

import tensorflow as tf
import numpy as np
from skimage import io
#from skimage import io
from glob import glob
import os
# Number GPUs in PCI bus order (NOTE(review): this keeps CUDA's indices in
# line with nvidia-smi) and expose only device 1 to this process.
os.environ.update(CUDA_DEVICE_ORDER="PCI_BUS_ID", CUDA_VISIBLE_DEVICES="1")
def read_tfrecord(tfrecords_filename, image_shape=(6, 320, 320), mask_shape=None):
    """Build a queue-based (TF1) reader that yields one decoded example.

    Args:
        tfrecords_filename: a path, or a list/tuple of paths, to TFRecord files.
        image_shape: shape the decoded float32 ``image/encoded`` bytes are
            reshaped to. Defaults to the previously hard-coded (6, 320, 320).
        mask_shape: optional shape for the decoded uint8 mask. ``None`` (the
            default) keeps the old behaviour and returns a flat 1-D tensor of
            unknown length. Pass e.g. ``(320, 320)`` to get a fully defined
            shape, which ``tf.train.batch``/``tf.train.shuffle_batch`` require.

    Returns:
        ``(image, gt_mask)`` tensors for a single example. They can only be
        evaluated while ``tf.train.start_queue_runners`` threads are running,
        because the file names are fed through ``string_input_producer``.
    """
    if not isinstance(tfrecords_filename, (list, tuple)):
        tfrecords_filename = [tfrecords_filename]
    filename_queue = tf.train.string_input_producer(list(tfrecords_filename))

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image/encoded': tf.FixedLenFeature([], tf.string),
            'image/segmentation/encoded': tf.FixedLenFeature([], tf.string),
        })
    # decode_raw reinterprets the stored bytes as-is, so these dtypes must
    # match what the writer serialized (assumed float32 images and uint8
    # masks -- confirm against the source arrays).
    image = tf.decode_raw(features['image/encoded'], tf.float32)
    image = tf.reshape(image, list(image_shape))
    gt_mask = tf.decode_raw(features['image/segmentation/encoded'], tf.uint8)
    if mask_shape is not None:
        gt_mask = tf.reshape(gt_mask, list(mask_shape))
    return image, gt_mask


if __name__ == '__main__':
    # Sanity check: read one batch back from the training shards and tile the
    # image slices into a PNG for visual comparison with the source data.
    files = sorted(glob('/train*'))  # placeholder glob; point it at the shards
    if not files:
        raise SystemExit('no TFRecord files match /train*')

    batch_size = 256
    min_after_dequeue = 20
    # tf.train.shuffle_batch recommends
    #   capacity = min_after_dequeue + (num_threads + margin) * batch_size;
    # the previous capacity=30 could hold only a small fraction of one batch.
    capacity = min_after_dequeue + 3 * batch_size
    num_rows = 64  # examples shown in the preview grid (one per row)
    num_cols = 4   # leading channels shown per example (one per column)

    with tf.Session() as sess:
        image, gt_mask = read_tfrecord(files)
        # shuffle_batch needs fully defined shapes, so give the mask its 2-D
        # shape (a no-op if the reader already returned it shaped).
        gt_mask = tf.reshape(gt_mask, (320, 320))
        image_batch, gt_batch = tf.train.shuffle_batch(
            [image, gt_mask], batch_size=batch_size, capacity=capacity,
            min_after_dequeue=min_after_dequeue, num_threads=1)

        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())

        # Start the queue-runner threads that feed the input queues, and stop
        # them even if the read below raises.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            ib, gb = sess.run([image_batch, gt_batch])
        finally:
            coord.request_stop()
            coord.join(threads)

    print(ib.shape)
    print(gb.shape)

    # Tile ib[i, j] (the 2-D slice of example i, channel j) into a grid with
    # one example per row and one channel per column. The old buffer was
    # sized for 256 rows but only 64 were filled, leaving 3/4 of it zeros.
    rows = min(num_rows, ib.shape[0])
    cols = min(num_cols, ib.shape[1])
    data = np.vstack([np.hstack([ib[i, j] for j in range(cols)])
                      for i in range(rows)])

    # Depending on the skimage/imageio version, saving float data as PNG
    # either raises for values outside [-1, 1] or rescales it lossily with a
    # warning; min-max scaling to uint8 makes the output well defined.
    lo, hi = float(data.min()), float(data.max())
    scale = 255.0 / (hi - lo) if hi > lo else 0.0
    preview = np.clip((data - lo) * scale, 0, 255).astype(np.uint8)
    io.imsave('vis_tfrecord.png', preview)

可視化的結果圖就不方便展示了。

發佈了46 篇原創文章 · 獲贊 34 · 訪問量 7萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章