from torch.utils.data import Dataset, DataLoader
import numpy as np
import SimpleITK as sitk
import os
class My_MR_Dataset(Dataset):
    """Dataset of MR volumes and their segmentation masks read from one folder.

    Every regular file in ``src_folder`` whose name contains neither
    ``'segmentation'`` nor ``'raw'`` is treated as an image. Its mask is the
    file with ``'_segmentation'`` inserted before the extension (e.g.
    ``Case00.mhd`` -> ``Case00_segmentation.mhd``). All volumes are loaded,
    resampled and cropped eagerly in ``__init__``.

    Each item is an ``(image, mask)`` pair of numpy arrays in SimpleITK array
    order ``(slice, y, x)``:
    - images are rescaled to [0, 1] and returned as float64;
    - masks are binarized and returned as float32.
    The in-plane size is ``params['VolSize'][:2]``. The number of slices is
    that of the source volume, so different patients generally have
    different depths.

    Required ``params`` keys:
        VolSize: crop size ``[x, y, z]``; only x and y are used (the z entry
            is replaced per volume by the source depth). Not modified.
        normDir: if true, resample through the inverse of the image's
            direction matrix.
    ``params['dstRes']`` is written with the output spacing of the last
    processed volume.
    """

    def __init__(self, src_folder, params):
        self.params = params
        self.src_folder = src_folder
        self.loadTrainingData()
        self.images_numpy = self.getNumpyImages()
        self.gt_numpy = self.getNumpyGt()
        assert len(self.images_numpy) == len(self.gt_numpy)
        self.len = len(self.images_numpy)
        # Both dicts were filled in file_list order, so equal positions form an
        # (image, mask) pair. Cache the key order once instead of rebuilding a
        # list of all keys on every __getitem__ call.
        self._image_keys = list(self.images_numpy)
        self._gt_keys = list(self.gt_numpy)

    def __getitem__(self, index):
        """Return the ``(image, mask)`` numpy pair of the volume at ``index``."""
        return (self.images_numpy[self._image_keys[index]],
                self.gt_numpy[self._gt_keys[index]])

    def __len__(self):
        return self.len

    def get_new_spacing(self, sitkImage, new_size, x_rate=0.45, y_rate=0.45):
        """Compute the output spacing used to resample ``sitkImage``.

        The kept in-plane extent is ``trunc(size * rate)`` voxels of the
        original spacing. It is divided by ``new_size[0:2]``, so a crop of
        ``new_size`` voxels at the returned spacing covers that extent. The
        slice spacing is kept unchanged.

        :param sitkImage: SimpleITK image (only its size and spacing are read).
        :param new_size: target ``[x, y, z]`` size; the z entry is ignored.
            The argument is NOT modified.
        :param x_rate: fraction of the x extent that is kept (default 0.45).
        :param y_rate: fraction of the y extent that is kept (default 0.45).
        :return: numpy float array ``[sx, sy, sz]``.
        """
        old_size = np.array(sitkImage.GetSize())
        old_spacing = np.array(sitkImage.GetSpacing())
        # Private copy: writing into the argument used to overwrite the
        # caller's params['VolSize'].
        target_size = np.array(new_size, dtype=float)
        target_size[-1] = old_size[-1]
        # Truncate the kept in-plane size to whole voxels (same as assigning a
        # float into the integer size array).
        old_size[0] = int(old_size[0] * x_rate)
        old_size[1] = int(old_size[1] * y_rate)
        new_spacing = old_size * old_spacing / target_size
        new_spacing[-1] = old_spacing[-1]
        return new_spacing

    def getNumpyImages(self):
        """Resample/crop all images (linear interpolation) to numpy arrays."""
        data = self.getNumpyData(self.sitk_images, sitk.sitkLinear)
        return data

    def getNumpyGt(self):
        """Resample/crop all masks like the images (linear interpolation).

        After resampling, each mask is re-binarized at 0.5 to float32 {0, 1}.
        """
        data = self.getNumpyData(self.sitk_gt, sitk.sitkLinear)
        for key in data:
            data[key] = (data[key] > 0.5).astype(dtype=np.float32)
        return data

    def getNumpyData(self, data, method):
        '''Resample and crop every SimpleITK image in ``data``.

        Must be kept in sync with ``read_mhd`` in mhd.py.

        For each volume:
        - the output spacing comes from ``get_new_spacing``;
        - the volume is resampled to a size that is at least ``VolSize``
          in-plane;
        - a ``VolSize[0] x VolSize[1]`` window is cropped, centred in x and
          shifted from the centre halfway towards index 0 in y;
        - all slices are kept.

        :param data: dict mapping a key (file name) to a SimpleITK image.
        :param method: SimpleITK interpolator, e.g. ``sitk.sitkLinear``.
        :return: dict with the same keys mapping to float64 numpy arrays in
            ``(slice, y, x)`` order. The cropped SimpleITK images are stored in
            ``self.sitkImage``, which is reset on every call.
        '''
        ret = dict()
        self.sitkImage = dict()
        for key, img in data.items():
            # Per-volume crop size: in-plane from params, depth = all slices of
            # this volume. A private copy keeps params['VolSize'] intact.
            vol_size = np.array(self.params['VolSize'], dtype=int)
            vol_size[-1] = img.GetSize()[-1]
            dst_res = self.get_new_spacing(img, vol_size)
            # Kept for compatibility: exposes the spacing of the last volume.
            self.params['dstRes'] = dst_res

            # Size that holds the whole volume at the new spacing, but never
            # smaller than the crop window. In z the spacing is unchanged, so
            # the z size is always the source depth.
            factor = np.asarray(img.GetSpacing()) / dst_res
            factor_size = np.asarray(img.GetSize() * factor, dtype=float)
            new_size = np.max([factor_size, vol_size], axis=0).astype(dtype=int)

            resampler = sitk.ResampleImageFilter()
            # Take origin/direction (and defaults for spacing/size) from img,
            # then override spacing and size.
            resampler.SetReferenceImage(img)
            resampler.SetOutputSpacing(dst_res.tolist())
            resampler.SetSize(new_size.tolist())
            resampler.SetInterpolator(method)
            if self.params['normDir']:
                # Resample through the inverse of the image's direction matrix.
                transform = sitk.AffineTransform(3)
                transform.SetMatrix(img.GetDirection())
                resampler.SetTransform(transform.GetInverse())
            img_resampled = resampler.Execute(img)

            # Crop start index. new_size >= vol_size, so it is never negative
            # and the window always fits inside the resampled image.
            centroid = np.asarray(new_size, dtype=float) / 2.0
            start_px = (centroid - vol_size / 2.0).astype(dtype=int)
            # Move the window from the centre halfway towards y index 0.
            start_px[1] = start_px[1] // 2

            region_extractor = sitk.RegionOfInterestImageFilter()
            region_extractor.SetSize(vol_size.tolist())
            region_extractor.SetIndex(start_px.tolist())
            img_resampled_cropped = region_extractor.Execute(img_resampled)
            self.sitkImage[key] = img_resampled_cropped
            # NOTE(review): float64 kept for backward compatibility while masks
            # are float32 (getNumpyGt) - confirm what the training code expects.
            ret[key] = sitk.GetArrayFromImage(img_resampled_cropped).astype(dtype=float)
        return ret

    def loadTrainingData(self):
        """Build the image/mask file lists and read all volumes."""
        self.createImageFileList()
        self.createGTFileList()
        self.loadImages()
        self.loadGT()

    def createGTFileList(self):
        """Derive each mask file name by inserting '_segmentation' before the extension."""
        self.gt_list = list()
        for f in self.file_list:
            stem, ext = os.path.splitext(f)
            self.gt_list.append(stem + '_segmentation' + ext)

    def createImageFileList(self):
        """List the image files of ``src_folder`` in sorted (reproducible) order.

        Directories are skipped. Names containing 'segmentation' (masks) or
        'raw' (presumably MHD pixel-data files) are skipped too.
        """
        self.file_list = sorted(
            f for f in os.listdir(self.src_folder)
            if os.path.isfile(os.path.join(self.src_folder, f))
            and 'segmentation' not in f and 'raw' not in f)
        print('File List:' + str(self.file_list))

    def loadImages(self):
        """Read every image as float32 and rescale its intensities to [0, 1].

        Also stores the average of the per-volume mean intensities in
        ``self.mean_intensity_train``. It is 0.0 when there are no images.
        """
        self.sitk_images = dict()
        rescal_filt = sitk.RescaleIntensityImageFilter()
        rescal_filt.SetOutputMaximum(1)
        rescal_filt.SetOutputMinimum(0)
        stats = sitk.StatisticsImageFilter()
        m = 0.
        for f in self.file_list:
            self.sitk_images[f] = rescal_filt.Execute(
                sitk.Cast(sitk.ReadImage(os.path.join(self.src_folder, f)), sitk.sitkFloat32))
            stats.Execute(self.sitk_images[f])
            m += stats.GetMean()
        # An empty folder used to raise ZeroDivisionError here.
        self.mean_intensity_train = m / len(self.sitk_images) if self.sitk_images else 0.

    def loadGT(self):
        """Read every mask, binarize it at 0.5 and cast it to float32."""
        self.sitk_gt = dict()
        for f in self.gt_list:
            self.sitk_gt[f] = sitk.Cast(sitk.ReadImage(os.path.join(self.src_folder, f)) > 0.5, sitk.sitkFloat32)
if __name__ == '__main__':
    ##########################################################################################
    #
    # Notes
    # 1. This code only crops the bladder region out of the MR images.
    # 2. batch_size counts patients. Patients have different numbers of slices,
    #    so volumes of different patients cannot be stacked into one batch;
    #    batch_size must therefore be 1.
    #
    ##########################################################################################
    import sys

    # Data folder: first command-line argument, else the original default path.
    default_folder = '/home/zhangxuchang/laboratory/changv-net/data/test'
    src_folder = sys.argv[1] if len(sys.argv) > 1 else default_folder
    params = dict()
    params['VolSize'] = np.asarray([128, 128, 64], dtype=int)
    # If true, rotates the volume according to its transformation in the mhd
    # file (used for flipping). Not recommended; disabled for now.
    params['normDir'] = False
    my_mr_dataset = My_MR_Dataset(src_folder, params)
    train_loader = DataLoader(dataset=my_mr_dataset,
                              batch_size=1,  # must be 1, see note 2 above
                              shuffle=True)
    for images, gt in train_loader:
        # torch.Tensor [batch, slice, y, x] (SimpleITK arrays are z, y, x)
        print(np.shape(images))
        print(np.shape(gt))
        print()
读取MR图像为torch.Tensor格式
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.