# Read MR images into torch.Tensor format (讀取MR圖像爲torch.Tensor格式)


from torch.utils.data import Dataset, DataLoader
import numpy as np
import SimpleITK as sitk
import os

class My_MR_Dataset(Dataset):
    """Map-style dataset of resampled, cropped MR volumes and their masks.

    All loading and preprocessing happens in the constructor: every regular
    file in ``src_folder`` whose name contains neither ``segmentation`` nor
    ``raw`` is treated as an image, and ``<name>_segmentation<ext>`` in the
    same folder is read as its ground-truth mask.

    Args:
        src_folder: Directory holding the image and segmentation files.
        params: Configuration dict. Keys read here:
            ``'VolSize'``: 3-element numpy integer array, the crop size.
            Its last entry is overwritten IN PLACE with the slice count of
            each processed volume, so the depth of every sample equals the
            depth of its source image.
            ``'normDir'``: if truthy, resample through the inverse of the
            image direction matrix.
            ``'dstRes'`` is written into this dict for each processed
            volume (the spacing that volume was resampled to).
    """

    def __init__(self, src_folder, params):
        self.params = params
        self.src_folder = src_folder

        self.loadTrainingData()
        self.images_numpy = self.getNumpyImages()
        self.gt_numpy = self.getNumpyGt()

        assert len(self.images_numpy) == len(self.gt_numpy)

        # Pair image/mask keys once; both dicts are filled in the same file
        # order, so the i-th image key matches the i-th mask key.
        self._sample_keys = list(zip(self.images_numpy, self.gt_numpy))
        self.len = len(self.images_numpy)

    def __getitem__(self, index):
        """Return the ``(image, mask)`` numpy pair stored at ``index``."""
        image_key, gt_key = self._sample_keys[index]
        return self.images_numpy[image_key], self.gt_numpy[gt_key]

    def __len__(self):
        """Return the number of image/mask pairs."""
        return self.len

    def get_new_spacing(self, sitkImage, new_size, x_rate=0.45, y_rate=0.45):
        """Compute the spacing at which a reduced field of view fills ``new_size``.

        The in-plane extent of the image is shrunk to ``x_rate`` / ``y_rate``
        of its size in voxels (truncated to an integer) and divided by the
        in-plane entries of ``new_size``. The spacing along the last axis is
        kept as it is in ``sitkImage``.

        Args:
            sitkImage: Object exposing ``GetOrigin()``, ``GetSize()`` and
                ``GetSpacing()`` (a SimpleITK image).
            new_size: Target size, 3 elements. It is NOT modified.
            x_rate: Fraction of the first-axis size to keep.
            y_rate: Fraction of the second-axis size to keep.

        Returns:
            numpy array with the new spacing for the three axes.
        """
        origin = np.array(sitkImage.GetOrigin())
        old_size = np.array(sitkImage.GetSize())
        old_spacing = np.array(sitkImage.GetSpacing())

        # Work on a copy so the caller's size array is left untouched.
        target_size = np.array(new_size)
        target_size[-1] = old_size[-1]

        # Integer truncation matches storing the product in an int array.
        old_size[0] = int(old_size[0] * x_rate)
        old_size[1] = int(old_size[1] * y_rate)

        # Adding and subtracting the origin cancels mathematically; it is
        # kept so the floating-point result stays identical to earlier runs.
        new_spacing = ((old_size * old_spacing + origin) - origin) / target_size
        new_spacing[-1] = old_spacing[-1]

        return new_spacing

    def getNumpyImages(self):
        """Return ``{filename: array}`` for the resampled, cropped images."""
        return self.getNumpyData(self.sitk_images, sitk.sitkLinear)

    def getNumpyGt(self):
        """Return ``{filename: array}`` of binary float32 masks.

        Masks are resampled with linear interpolation and then thresholded
        at 0.5 to make them binary again.
        """
        data = self.getNumpyData(self.sitk_gt, sitk.sitkLinear)
        for key in data:
            data[key] = (data[key] > 0.5).astype(dtype=np.float32)
        return data

    def getNumpyData(self, data, method):
        """Resample and crop every volume in ``data`` and return numpy arrays.

        NOTE: the original author asks that this be kept in sync with
        ``read_mhd`` in ``mhd.py``.

        Side effects: resets ``self.sitkImage`` and fills it with the cropped
        SimpleITK images; updates ``self.params['VolSize'][-1]`` and
        ``self.params['dstRes']`` for each volume.

        Args:
            data: Dict mapping a key to a SimpleITK image.
            method: SimpleITK interpolator passed to the resampler.

        Returns:
            Dict mapping each key to the cropped volume as a float array
            (as produced by ``sitk.GetArrayFromImage``).
        """
        ret = dict()
        self.sitkImage = dict()
        for key, img in data.items():
            # The crop depth always follows the source image. This used to
            # happen as a hidden side effect inside get_new_spacing.
            vol_size = self.params['VolSize']
            vol_size[-1] = img.GetSize()[-1]

            dst_res = self.get_new_spacing(img, vol_size)
            self.params['dstRes'] = dst_res

            spacing = img.GetSpacing()
            # One of the images has a slice spacing of 6; treat it as 3.
            if spacing[2] == 6:
                spacing = (spacing[0], spacing[1], 3)

            # Ratio of old to new spacing gives the resampled size in voxels;
            # never go below the crop size so the crop always fits.
            factor = np.asarray(spacing) / np.asarray(dst_res)
            factor_size = np.asarray(np.asarray(img.GetSize()) * factor, dtype=float)
            new_size = np.max([factor_size, vol_size], axis=0).astype(dtype=int)

            resampler = sitk.ResampleImageFilter()
            # Start from the image's own geometry, then override spacing/size.
            resampler.SetReferenceImage(img)
            resampler.SetOutputSpacing([dst_res[0], dst_res[1], dst_res[2]])
            resampler.SetSize(new_size.tolist())
            resampler.SetInterpolator(method)

            if self.params['normDir']:
                # Resample through the inverse of the direction-matrix transform.
                direction_transform = sitk.AffineTransform(3)
                direction_transform.SetMatrix(img.GetDirection())
                resampler.SetTransform(direction_transform.GetInverse())

            img_resampled = resampler.Execute(img)

            # Centre the crop box, then halve its start index along axis 1
            # to shift the box away from the centre on that axis.
            img_centroid = np.asarray(new_size, dtype=float) / 2.0
            img_start_px = (img_centroid - vol_size / 2.0).astype(dtype=int)
            img_start_px[1] = int(img_start_px[1] * 0.5)

            region_extractor = sitk.RegionOfInterestImageFilter()
            region_extractor.SetSize(vol_size.tolist())
            region_extractor.SetIndex(img_start_px.tolist())
            img_resampled_cropped = region_extractor.Execute(img_resampled)

            self.sitkImage[key] = img_resampled_cropped
            ret[key] = sitk.GetArrayFromImage(img_resampled_cropped).astype(dtype=float)
        return ret

    def loadTrainingData(self):
        """Build the image and mask file lists, then read both from disk."""
        self.createImageFileList()
        self.createGTFileList()
        self.loadImages()
        self.loadGT()

    def createGTFileList(self):
        """Derive ``self.gt_list`` by inserting ``_segmentation`` before each extension."""
        self.gt_list = list()
        for f in self.file_list:
            filename, ext = os.path.splitext(f)
            self.gt_list.append(filename + '_segmentation' + ext)

    def createImageFileList(self):
        """Collect image file names from ``src_folder`` into ``self.file_list``.

        Only regular files whose name contains neither ``segmentation`` nor
        ``raw`` are kept. The list is sorted so that the sample order does
        not depend on the directory listing order.
        """
        self.file_list = sorted(
            f for f in os.listdir(self.src_folder)
            if os.path.isfile(os.path.join(self.src_folder, f))
            and 'segmentation' not in f and 'raw' not in f
        )
        print('File List:' + str(self.file_list))

    def loadImages(self):
        """Read every image as float32 rescaled to [0, 1] into ``self.sitk_images``.

        Also stores the mean of the per-image mean intensities in
        ``self.mean_intensity_train`` (0.0 when there are no images).
        """
        self.sitk_images = dict()
        rescal_filt = sitk.RescaleIntensityImageFilter()
        rescal_filt.SetOutputMaximum(1)
        rescal_filt.SetOutputMinimum(0)

        stats = sitk.StatisticsImageFilter()
        m = 0.
        for f in self.file_list:
            self.sitk_images[f] = rescal_filt.Execute(
                sitk.Cast(sitk.ReadImage(os.path.join(self.src_folder, f)), sitk.sitkFloat32))
            stats.Execute(self.sitk_images[f])
            m += stats.GetMean()

        # Guard against an empty folder instead of dividing by zero.
        count = len(self.sitk_images)
        self.mean_intensity_train = m / count if count else 0.0

    def loadGT(self):
        """Read every mask, binarise it at 0.5 and store it as float32 in ``self.sitk_gt``."""
        self.sitk_gt = dict()
        for f in self.gt_list:
            self.sitk_gt[f] = sitk.Cast(sitk.ReadImage(os.path.join(self.src_folder, f)) > 0.5, sitk.sitkFloat32)

if __name__ == '__main__':
    # Usage demo: build the dataset from a folder and print the shape of
    # every image/mask batch.
    #
    # Notes:
    #   1. This code only crops the bladder region out of the MR image.
    #   2. batch_size counts patient samples. Patients have different
    #      numbers of slices, so volumes of different depth cannot share a
    #      batch and batch_size has to stay at 1.

    src_folder = '/home/zhangxuchang/laboratory/changv-net/data/test'
    params = {
        'VolSize': np.asarray([128, 128, 64], dtype=int),
        # Whether to rotate the volume according to the transformation
        # stored in the mhd file. Not recommended; left off for now.
        'normDir': False,
    }

    train_loader = DataLoader(
        dataset=My_MR_Dataset(src_folder, params),
        batch_size=1,  # must be 1, see note 2
        shuffle=True,
    )

    # Each batch is an (images, gt) pair of tensors; SimpleITK arrays are
    # indexed z-y-x, so the layout is [batch, slice, y, x].
    for images, gt in train_loader:
        print(np.shape(images))
        print(np.shape(gt))
        print()

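# Web-page residue from the blog this file was copied from:
# Post a comment (發表評論)
# All comments (所有評論)
# No comments yet; to be the first, type in the comment box above and click publish.
# Related articles (相關文章)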