from torch.utils.data import Dataset, DataLoader
import numpy as np
import SimpleITK as sitk
import os
class My_MR_Dataset(Dataset):
    """Dataset of MR volumes and binary segmentations read from one folder.

    All loading, resampling and cropping happens eagerly in ``__init__``;
    ``__getitem__`` only looks up the pre-computed numpy arrays.

    Image files are the regular files in ``src_folder`` whose names contain
    neither ``'segmentation'`` nor ``'raw'``.  The label of ``name.ext`` is
    ``name_segmentation.ext`` in the same folder.

    Args:
        src_folder: directory holding the images and their segmentations.
        params: configuration dict.  Keys read:

            * ``'VolSize'`` -- output size in SimpleITK (x, y, z) order.  Only
              x and y are used; the z size of every sample equals the source
              image's own slice count.  This entry is never modified.
            * ``'normDir'`` -- if true, resample through the inverse of the
              image's direction matrix.

            ``params['dstRes']`` is (over)written with the output spacing of
            the most recently processed volume.
    """

    def __init__(self, src_folder, params):
        self.params = params
        self.src_folder = src_folder
        self.loadTrainingData()
        self.images_numpy = self.getNumpyImages()
        self.gt_numpy = self.getNumpyGt()
        assert len(self.images_numpy) == len(self.gt_numpy)
        # Cache the key order once instead of rebuilding it per __getitem__.
        # Both dicts are filled in file-list order, so position i in each
        # list refers to the same case.
        self._image_keys = list(self.images_numpy)
        self._gt_keys = list(self.gt_numpy)
        self.len = len(self.images_numpy)

    def __getitem__(self, index):
        """Return the ``(image, label)`` numpy arrays of the ``index``-th case."""
        return (self.images_numpy[self._image_keys[index]],
                self.gt_numpy[self._gt_keys[index]])

    def __len__(self):
        return self.len

    def get_new_spacing(self, sitkImage, new_size, x_rate=0.45, y_rate=0.45):
        """Return the output spacing that maps part of the x/y FOV onto ``new_size``.

        Along x and y only ``x_rate`` / ``y_rate`` of the image's voxel count
        (truncated to whole voxels) is kept, and that physical extent is
        spread over ``new_size[0]`` / ``new_size[1]`` output voxels.  The z
        spacing is returned unchanged.

        Args:
            sitkImage: image providing ``GetSize()`` and ``GetSpacing()``.
            new_size: requested output size (x, y, z); only read, never
                modified, and its z entry is ignored.
            x_rate, y_rate: fraction of the x / y voxel count to keep.

        Returns:
            float64 array ``[sx, sy, sz]``.
        """
        old_size = np.array(sitkImage.GetSize(), dtype=int)
        old_spacing = np.array(sitkImage.GetSpacing(), dtype=float)
        new_size = np.asarray(new_size)
        # Kept voxel counts; int() truncates like the integer-array assignment
        # this replaces.
        kept_voxels = np.array([int(old_size[0] * x_rate),
                                int(old_size[1] * y_rate)])
        new_spacing = old_spacing.copy()
        # Physical extent of the kept FOV divided by the output voxel count.
        new_spacing[:2] = kept_voxels * old_spacing[:2] / new_size[:2]
        return new_spacing

    def getNumpyImages(self):
        """Resample and crop the loaded images (linear interpolation)."""
        return self.getNumpyData(self.sitk_images, sitk.sitkLinear)

    def getNumpyGt(self):
        """Resample and crop the labels, then re-binarise them at 0.5 as float32.

        Labels are resampled with linear interpolation, so thresholding turns
        the interpolated values back into a 0/1 mask.
        """
        data = self.getNumpyData(self.sitk_gt, sitk.sitkLinear)
        for key in data:
            data[key] = (data[key] > 0.5).astype(dtype=np.float32)
        return data

    def getNumpyData(self, data, method):
        """Resample and crop every SimpleITK image in ``data``.

        Must be changed together with ``read_mhd`` in mhd.py.

        Args:
            data: dict mapping file name -> SimpleITK image.
            method: SimpleITK interpolator, e.g. ``sitk.sitkLinear``.

        Returns:
            dict with the same keys mapping to float64 numpy arrays in
            (z, y, x) order, of size (slice count, VolSize[1], VolSize[0]).

        Side effects:
            ``self.sitkImage`` is replaced by a dict of the cropped SimpleITK
            images of this call; ``params['dstRes']`` is updated per image.
        """
        ret = dict()
        self.sitkImage = dict()
        for key, img in data.items():
            cropped = self._resample_and_crop(img, method)
            self.sitkImage[key] = cropped
            ret[key] = sitk.GetArrayFromImage(cropped).astype(dtype=float)
        return ret

    def _resample_and_crop(self, img, method):
        """Resample ``img`` to the FOV-derived spacing and crop a VolSize x/y window."""
        # Per-image output size: x/y from VolSize, z = the image's own slice
        # count.  A copy is used so params['VolSize'] is never mutated.
        vol_size = np.array(self.params['VolSize'], dtype=int)
        vol_size[2] = img.GetSize()[2]
        dst_res = self.get_new_spacing(img, vol_size)
        # Kept for code that reads params['dstRes'] after loading.
        self.params['dstRes'] = dst_res

        spacing = np.array(img.GetSpacing(), dtype=float)
        # One of the images has a z spacing of 6; it is treated as 3 here.
        # NOTE(review): the output z spacing equals the image's own z spacing
        # and new_size[2] is clamped to at least the slice count below, so
        # this adjustment does not appear to change the result.
        if spacing[2] == 6:
            spacing[2] = 3
        # old/new spacing ratio -> size of the whole image after resampling.
        factor = spacing / dst_res
        factor_size = np.asarray(img.GetSize(), dtype=float) * factor
        # Never resample to fewer voxels than the crop window needs.
        new_size = np.maximum(factor_size, vol_size).astype(int)

        resampler = sitk.ResampleImageFilter()
        # Output origin and direction come from img; spacing and size are
        # overridden below.
        resampler.SetReferenceImage(img)
        resampler.SetOutputSpacing(dst_res.tolist())
        resampler.SetSize(new_size.tolist())
        resampler.SetInterpolator(method)
        if self.params['normDir']:
            direction = sitk.AffineTransform(3)
            direction.SetMatrix(img.GetDirection())
            resampler.SetTransform(direction.GetInverse())
        img_resampled = resampler.Execute(img)

        # Crop window start: centred in x and z.  In y the start index is
        # halved, moving the window from the centre towards index 0 ("up").
        # new_size >= vol_size element-wise, so start is non-negative and
        # //= 2 equals the truncated half.
        start = (new_size / 2.0 - vol_size / 2.0).astype(int)
        start[1] //= 2
        region_extractor = sitk.RegionOfInterestImageFilter()
        region_extractor.SetSize(vol_size.tolist())
        region_extractor.SetIndex(start.tolist())
        return region_extractor.Execute(img_resampled)

    def loadTrainingData(self):
        """Discover the image/label files and read them with SimpleITK."""
        self.createImageFileList()
        self.createGTFileList()
        self.loadImages()
        self.loadGT()

    def createGTFileList(self):
        """Derive each label name by inserting '_segmentation' before the extension."""
        self.gt_list = []
        for f in self.file_list:
            stem, ext = os.path.splitext(f)
            self.gt_list.append(stem + '_segmentation' + ext)

    def createImageFileList(self):
        """List image files in ``src_folder``, sorted for a reproducible order."""
        # os.listdir order is unspecified; sorting makes indices stable.
        self.file_list = sorted(
            f for f in os.listdir(self.src_folder)
            if os.path.isfile(os.path.join(self.src_folder, f))
            and 'segmentation' not in f and 'raw' not in f
        )
        print('File List:' + str(self.file_list))

    def loadImages(self):
        """Read every image as float32, rescale it to [0, 1] and record the mean intensity.

        Raises:
            ValueError: if no image files were found.
        """
        if not self.file_list:
            raise ValueError(f'No image files found in {self.src_folder!r}')
        self.sitk_images = dict()
        rescal_filt = sitk.RescaleIntensityImageFilter()
        rescal_filt.SetOutputMaximum(1)
        rescal_filt.SetOutputMinimum(0)
        stats = sitk.StatisticsImageFilter()
        m = 0.
        for f in self.file_list:
            self.sitk_images[f] = rescal_filt.Execute(
                sitk.Cast(sitk.ReadImage(os.path.join(self.src_folder, f)),
                          sitk.sitkFloat32))
            stats.Execute(self.sitk_images[f])
            m += stats.GetMean()
        # Average of the per-image mean intensities (after rescaling).
        self.mean_intensity_train = m / len(self.sitk_images)

    def loadGT(self):
        """Read every label, binarise it at 0.5 and cast it to float32."""
        self.sitk_gt = dict()
        for f in self.gt_list:
            self.sitk_gt[f] = sitk.Cast(
                sitk.ReadImage(os.path.join(self.src_folder, f)) > 0.5,
                sitk.sitkFloat32)
if __name__ == '__main__':
    ##########################################################################################
    #
    # Notes
    # 1. This code only crops the bladder region out of the MR images.
    # 2. batch_size counts patients.  Patients have different slice counts, and
    #    volumes with different slice counts cannot be stacked into one batch,
    #    so batch_size can only be 1.
    #
    ##########################################################################################
    import sys

    # Data folder: first command-line argument if given, else the default path.
    src_folder = (sys.argv[1] if len(sys.argv) > 1
                  else '/home/zhangxuchang/laboratory/changv-net/data/test')
    params = dict()
    params['VolSize'] = np.asarray([128, 128, 64], dtype=int)
    # If true, rotate the volume according to the direction stored in the mhd
    # file (used for flipping images).  Not recommended; disabled for now.
    params['normDir'] = False
    my_mr_dataset = My_MR_Dataset(src_folder, params)
    train_loader = DataLoader(dataset=my_mr_dataset,
                              batch_size=1,  # must be 1, see note 2 above
                              shuffle=True)
    for images, gt in train_loader:
        # torch.Tensor of shape [batch, z, y, x] (SimpleITK arrays are z, y, x).
        print(images.shape)
        print(gt.shape)
        print()
# 讀取MR圖像爲torch.Tensor格式 (Reading MR images as torch.Tensor)
# 發表評論
# 所有評論
# 還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.