构建2D U-net来做BraTS2019

研二了,第一次用深度学习做分割,感觉自己落伍了好多,方向是图像处理,一直在用传统法拼拼凑凑,同学都说深度学习要发文章得有好的数学基础,自知数学基础差的情况下还是要接触一下的,毕竟万事开头难,不学习就永远不会,记录一下自己的学习过程。
环境

网络结构

首先根据前人经验先搭建网络,下图是Unet的网络结构图:
在这里插入图片描述
观察到conv操作蛮多的,不管是下采样层还是反卷积层中都用到,那么先写一个该操作的class打包一下:

class Conv3x3(nn.Module):
    """Two stacked 3x3 convolution stages, each followed by BatchNorm and ReLU.

    Both convolutions use stride 1 and padding 1, so the height and width of
    the input are preserved; only the channel count changes.

    Args:
        inputCh: number of channels of the incoming feature map.
        outputCh: number of channels produced by both convolution stages.
    """

    def __init__(self, inputCh, outputCh):
        super(Conv3x3, self).__init__()

        # First stage: inputCh -> outputCh.
        self.conv1 = nn.Sequential(
            nn.Conv2d(inputCh, outputCh, kernel_size=3, stride=1, padding=1),
            # The original passed `outPutCh`, an undefined name (NameError).
            nn.BatchNorm2d(outputCh),
            nn.ReLU(inplace=True),
        )

        # Second stage: outputCh -> outputCh.
        self.conv2 = nn.Sequential(
            nn.Conv2d(outputCh, outputCh, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(outputCh),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply both convolution stages in order and return the result."""
        x = self.conv1(x)
        x = self.conv2(x)
        return x

打包完卷积的操作之后,再把上采样的操作整理一下:

class TransConv(nn.Module):
    """Transposed-convolution block: ConvTranspose2d -> BatchNorm2d -> ReLU.

    With kernel 3, stride 2, padding 1 and output_padding 1 the transposed
    convolution doubles the height and width of its input.

    Args:
        inputCh: number of channels of the incoming feature map.
        outputCh: number of channels of the upsampled feature map.
    """

    def __init__(self, inputCh, outputCh):
        super().__init__()
        upsample = nn.ConvTranspose2d(
            inputCh,
            outputCh,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            dilation=1,
        )
        normalize = nn.BatchNorm2d(outputCh)
        activate = nn.ReLU(inplace=True)
        self.conv = nn.Sequential(upsample, normalize, activate)

    def forward(self, x):
        """Return the upsampled, normalized and activated feature map."""
        return self.conv(x)

class UpSam(nn.Module):
    """Decoder step: upsample, concatenate a skip feature map, then convolve.

    Args:
        inputCh: channels of the feature map coming from the deeper level.
        outputCh: channels after upsampling and after the final convolution.
    """

    def __init__(self, inputCh, outputCh):
        super().__init__()
        # Transposed convolution taking inputCh channels to outputCh channels.
        self.upconv = TransConv(inputCh, outputCh)
        # Built for 2 * outputCh input channels: the upsampled map plus a
        # skip feature map that is expected to carry outputCh channels too.
        self.conv = Conv3x3(2 * outputCh, outputCh)

    def forward(self, x, convfeatures):
        """Upsample x, join it with convfeatures on dim 1 and convolve."""
        upsampled = self.upconv(x)
        merged = torch.cat((upsampled, convfeatures), dim=1)
        return self.conv(merged)

至此完成图中蓝色箭头,灰色箭头,绿色箭头的定义,红色箭头是maxpool,实质是下采样,可以跟其他block组合到一起,整体网络如下:

class UNet(nn.Module):
    """2D U-Net with four pooling levels and four upsampling levels.

    Channel widths per level are size, 2*size, 4*size, 8*size and 16*size.
    Each pooling step halves height and width and each UpSam step doubles
    them again, so the input height and width should be divisible by 16 for
    the skip concatenations to line up.

    Args:
        inputCh: number of input channels (default 4; the original author
            feeds four imaging modalities).
        outputCh: number of output channels, one per class (default 5).
        size: channel width of the first level; deeper levels double it.
    """

    def __init__(self, inputCh=4, outputCh=5, size=64):
        super().__init__()
        # Channel width of each of the five levels.
        channels = [size * (2 ** level) for level in range(5)]

        def pooled_block(in_channels, out_channels):
            # 2x2 max-pool followed by the double 3x3 convolution block.
            return nn.Sequential(
                nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                Conv3x3(in_channels, out_channels),
            )

        # Contracting path.
        self.downLayer1 = Conv3x3(inputCh, channels[0])
        self.downLayer2 = pooled_block(channels[0], channels[1])
        self.downLayer3 = pooled_block(channels[1], channels[2])
        self.downLayer4 = pooled_block(channels[2], channels[3])
        self.bottomLayer = pooled_block(channels[3], channels[4])

        # Expanding path.
        self.upLayer1 = UpSam(channels[4], channels[3])
        self.upLayer2 = UpSam(channels[3], channels[2])
        self.upLayer3 = UpSam(channels[2], channels[1])
        self.upLayer4 = UpSam(channels[1], channels[0])

        # Final 3x3 convolution producing one channel per output class.
        self.outLayer = nn.Conv2d(channels[0], outputCh, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Run the network; the output keeps the input's height and width."""
        # Contracting path (channels, height, width per sample).
        skip1 = self.downLayer1(x)          # size    x H    x W
        skip2 = self.downLayer2(skip1)      # 2*size  x H/2  x W/2
        skip3 = self.downLayer3(skip2)      # 4*size  x H/4  x W/4
        skip4 = self.downLayer4(skip3)      # 8*size  x H/8  x W/8
        # Bottom of the "U".
        bottom = self.bottomLayer(skip4)    # 16*size x H/16 x W/16
        # Expanding path, reusing the matching contracting-path features.
        out = self.upLayer1(bottom, skip4)  # 8*size  x H/8  x W/8
        out = self.upLayer2(out, skip3)     # 4*size  x H/4  x W/4
        out = self.upLayer3(out, skip2)     # 2*size  x H/2  x W/2
        out = self.upLayer4(out, skip1)     # size    x H    x W
        return self.outLayer(out)           # outputCh x H   x W

网络构建完毕,写个main函数验证下看看:

if __name__ == "__main__":
    # Smoke test: push a random batch through the network and print shapes.
    # UNet's third parameter is named `size`; the original call passed
    # `degree=64`, which raised TypeError (unexpected keyword argument).
    net = UNet(4, 5, size=64)
    batch_size = 4
    a = torch.randn(batch_size, 4, 192, 192)  # random stand-in for real data
    b = net(a)
    print(a.shape)
    print(b.shape)

在这里插入图片描述
可以从调试结果看到,网络输出与网络输入的batch大小和空间尺寸一致:输入为(4,4,192,192),包含4个模态;输出为(4,5,192,192),包含5个类别。这与我们期望的结果吻合,应该没啥问题,那么就准备写DataLoader了;

代码

以下这部分是2D-UNet网络的代码,如果觉得卷积操作和上采样操作比较占篇幅,写个import包括进去就好了;

import sys
import math
import torch
import torch.nn as nn

class Conv3x3(nn.Module):
    """Double 3x3 convolution block (Conv2d -> BatchNorm2d -> ReLU, twice).

    Stride 1 and padding 1 keep height and width unchanged; the first stage
    maps in_ch to out_ch channels and the second keeps out_ch channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = self._stage(in_ch, out_ch)
        self.conv2 = self._stage(out_ch, out_ch)

    @staticmethod
    def _stage(src_ch, dst_ch):
        # One convolution stage: 3x3 conv, batch norm, in-place ReLU.
        layers = [
            nn.Conv2d(src_ch, dst_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(dst_ch),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return x after both convolution stages."""
        return self.conv2(self.conv1(x))


class TransConv(nn.Module):
    """Upsampling block: ConvTranspose2d followed by BatchNorm2d and ReLU.

    The transposed convolution (kernel 3, stride 2, padding 1,
    output_padding 1) doubles the input's height and width while mapping
    inputCh channels to outputCh channels.
    """

    def __init__(self, inputCh, outputCh):
        super().__init__()
        deconv_kwargs = dict(kernel_size=3, stride=2, padding=1, output_padding=1, dilation=1)
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(inputCh, outputCh, **deconv_kwargs),
            nn.BatchNorm2d(outputCh),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Return the upsampled feature map."""
        upsampled = self.conv(x)
        return upsampled

class UpSam(nn.Module):
    """One decoder level: transposed conv, skip concatenation, double conv.

    Args:
        inputCh: channels of the deeper-level feature map being upsampled.
        outputCh: channels produced by the upsampling and by the block.
    """

    def __init__(self, inputCh, outputCh):
        super().__init__()
        # Transposed convolution: inputCh -> outputCh channels.
        self.upconv = TransConv(inputCh, outputCh)
        # The double convolution takes 2 * outputCh channels, i.e. the
        # upsampled map concatenated with a skip map of outputCh channels.
        self.conv = Conv3x3(2 * outputCh, outputCh)

    def forward(self, x, convfeatures):
        """Upsample x, concatenate convfeatures along dim 1, then convolve."""
        stacked = torch.cat((self.upconv(x), convfeatures), dim=1)
        return self.conv(stacked)

class UNet2D(nn.Module):
    """2D U-Net with four pooling levels and four upsampling levels.

    Channel widths per level are degree, 2*degree, 4*degree, 8*degree and
    16*degree. Each pooling step halves height and width and each UpSam step
    doubles them, so the input height and width should be divisible by 16
    for the skip concatenations to line up.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels (one per class).
        degree: channel width of the first level; deeper levels double it.
    """

    def __init__(self, in_ch=4, out_ch=2, degree=64):
        super(UNet2D, self).__init__()

        # Channel width of each of the five levels.
        chs = [(2 ** i) * degree for i in range(5)]

        # Contracting path: every level below the first starts with a 2x2
        # max-pool that halves the spatial size.
        self.downLayer1 = Conv3x3(in_ch, chs[0])
        self.downLayer2 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                        Conv3x3(chs[0], chs[1]))

        self.downLayer3 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                        Conv3x3(chs[1], chs[2]))

        self.downLayer4 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                        Conv3x3(chs[2], chs[3]))

        self.bottomLayer = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                         Conv3x3(chs[3], chs[4]))

        # Expanding path. The original referenced `Upsam`, which is not
        # defined anywhere (the class is `UpSam`) and raised NameError.
        self.upLayer1 = UpSam(chs[4], chs[3])
        self.upLayer2 = UpSam(chs[3], chs[2])
        self.upLayer3 = UpSam(chs[2], chs[1])
        self.upLayer4 = UpSam(chs[1], chs[0])

        # Final 3x3 convolution producing one channel per output class.
        self.outLayer = nn.Conv2d(chs[0], out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Run the network; the output keeps the input's height and width."""
        # Contracting path; x1..x4 are kept as skip connections.
        x1 = self.downLayer1(x)
        x2 = self.downLayer2(x1)
        x3 = self.downLayer3(x2)
        x4 = self.downLayer4(x3)
        x5 = self.bottomLayer(x4)

        # Expanding path, pairing each level with its skip connection.
        x = self.upLayer1(x5, x4)
        x = self.upLayer2(x, x3)
        x = self.upLayer3(x, x2)
        x = self.upLayer4(x, x1)
        x = self.outLayer(x)
        return x


if __name__ == "__main__":
    # Smoke test: build the network, feed a random batch, print output shape.
    batch_size = 4
    net = UNet2D(4, 5, degree=64)
    sample = torch.randn(batch_size, 4, 192, 192)
    print(net(sample).shape)


数据读入

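写好网络以后,就该写读入方式,这里参考了并修改了别人的读写方式,有点点繁琐,首先需要获取HGG/LGG的全部文件夹名字,具体如何获取可以搜索“获取文件夹下的子文件夹名字”,然后pip install SimpleITK,如果下载得慢可以换清华镜像源,用SimpleITK读入nii文件也比较容易。注意下面的数据读入代码需要先导入用到的库:import os、import numpy as np、import torch、import SimpleITK as sitk,以及from torch.utils.data import Dataset。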

读入NII

def load_nii_as_array(img_name):
    """Read the image file at path *img_name* and return it as an ndarray.

    The original author notes that the volumes used here come back with
    shape [155, 240, 240]; that depends on the file being read.
    """
    return sitk.GetArrayFromImage(sitk.ReadImage(img_name))

简单的数据归一化

def norm_vol(data):
    """Min-max normalize the non-zero entries of *data* into [0, 1].

    Zero entries are treated as background and left at zero. The smallest
    non-zero value maps to 0 and the largest to 1, computed over the whole
    array at once (not per channel). The input is not modified.

    Args:
        data: numeric array (or array-like) of any shape.

    Returns:
        A new float64 ndarray of the same shape. If *data* has no non-zero
        entries, or all non-zero entries are equal, the values are returned
        unchanged (only converted to float64).
    """
    # `np.float` was removed in NumPy 1.24; float64 is the type it aliased.
    data = np.asarray(data).astype(np.float64)
    foreground = data != 0
    # An all-background array has nothing to normalize. The original fell
    # through to np.max on an empty selection, which raises ValueError.
    if not foreground.any():
        return data
    values = data[foreground]
    smin = values.min()
    smax = values.max()
    if smax == smin:
        # Constant foreground: avoid dividing by zero.
        return data
    data[foreground] = (values - smin) / (smax - smin)
    return data

读入的类

需要注意的是,我把所有文件名写成了conf文件,内容是这样的:

HGG/BraTS19_2013_10_1
HGG/BraTS19_2013_11_1
HGG/BraTS19_2013_12_1
HGG/BraTS19_2013_13_1
HGG/BraTS19_2013_14_1
HGG/BraTS19_2013_17_1
HGG/BraTS19_2013_18_1
HGG/BraTS19_2013_19_1
HGG/BraTS19_2013_20_1
HGG/BraTS19_2013_21_1
HGG/BraTS19_2013_22_1
HGG/BraTS19_2013_23_1
HGG/BraTS19_2013_25_1
HGG/BraTS19_2013_26_1
HGG/BraTS19_2013_27_1
class DataLoader19(Dataset):
    """In-memory dataset over subject folders listed in a conf file.

    Every line of the conf file is a subject folder relative to ``data_dir``.
    All subjects are loaded and normalized eagerly in ``__init__``.

    In training mode (``train is True``) each item is one 2D slice
    ``[image, label, name]`` and slices whose image is entirely zero are
    skipped. Otherwise each item is one whole subject
    ``[volume, label, subject]`` with the volume transposed so the slice
    axis comes first.

    Attributes set by ``__init__``:
        data: list of ``[image, label, name]`` entries.
        freq: relative frequency of label values 0..4 over all subjects.
        weight: ``median(freq) / freq`` per class, 0 for absent classes.
        zero_vol: all-zero (4, 240, 240) array, kept for compatibility.
    """

    def __init__(self, data_dir, conf='../config/train19.conf', train=True):
        """Load every subject listed in *conf* from below *data_dir*.

        Args:
            data_dir: root directory that the conf entries are relative to.
            conf: path of a text file with one subject folder per line.
            train: ``True`` stores non-empty 2D slices; anything else
                stores whole volumes.
        """
        # Close the conf file deterministically and ignore blank lines,
        # which would otherwise resolve to data_dir itself.
        with open(conf) as conf_file:
            entries = [line.strip() for line in conf_file if line.strip()]
        img_lists = [os.path.join(data_dir, entry) for entry in entries]

        self.data = []
        self.freq = np.zeros(5)
        # No longer used internally; kept in case callers read it.
        self.zero_vol = np.zeros((4, 240, 240))
        for count, subject in enumerate(img_lists, start=1):
            if count % 10 == 0:
                print('loading imageSets %d' % count)
            # volume: modalities x slices x H x W, label: slices x H x W
            volume, label = DataLoader19.get_subject(subject)
            volume = norm_vol(volume)

            self.freq += self.get_freq(label)
            if train is True:
                for i in range(volume.shape[1]):
                    image = volume[:, i, :, :]
                    # When training, ignore slices that contain no signal.
                    # Same test as comparing against an all-zero array, but
                    # independent of the slice size and channel count.
                    if not image.any():
                        continue
                    name = subject + '=slice' + str(i)
                    self.data.append([image, label[i, :, :], name])
            else:
                volume = np.transpose(volume, (1, 0, 2, 3))
                self.data.append([volume, label, subject])

        self.freq, self.weight = self._class_weights(self.freq)
        print('********  Finish loading data  ********')
        print('********  Weight for all classes  ********')
        print(self.weight)
        if train is True:
            print('********  Total number of 2D images is ' + str(len(self.data)) + ' **********')
        else:
            print('********  Total number of subject is ' + str(len(self.data)) + ' **********')

    def __len__(self):
        """Return the number of stored items (slices or subjects)."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, label, name)`` with both arrays as float tensors."""
        image, label, name = self.data[index]
        image = torch.from_numpy(image).float()
        label = torch.from_numpy(label).float()
        return image, label, name

    @staticmethod
    def get_subject(subject):
        """Load the image volume(s) and the label volume of one subject.

        Args:
            subject: path of the subject folder.

        Returns:
            ``(volume, label)`` where ``volume`` stacks the loaded modality
            images along a new first axis and ``label`` is the array read
            from the file whose name contains ``'seg'``.

        Raises:
            FileNotFoundError: if the folder has no modality or no seg file.
        """
        multi_mode_dir = []
        label_dir = ""
        # Sorted so the modality order does not depend on the file system.
        for f in sorted(os.listdir(subject)):
            # NOTE(review): only files containing 'flair' are loaded; the
            # original comment suggests 't1', 't1ce' and 't2' were disabled
            # on purpose. Confirm before feeding a 4-channel network.
            if 'flair' in f:
                multi_mode_dir.append(f)
            elif 'seg' in f:
                label_dir = f

        if not multi_mode_dir:
            raise FileNotFoundError('no modality image found in ' + subject)
        if not label_dir:
            raise FileNotFoundError('no seg file found in ' + subject)

        multi_mode_imgs = [
            load_nii_as_array(os.path.join(subject, mod_dir))
            for mod_dir in multi_mode_dir
        ]
        label = load_nii_as_array(os.path.join(subject, label_dir))
        volume = np.asarray(multi_mode_imgs)
        return volume, label

    def get_freq(self, label):
        """Count how many entries of *label* equal each of the values 0..4."""
        label = np.asarray(label)
        return np.array([np.sum(label == i) for i in range(5)], dtype=np.float64)

    @staticmethod
    def _class_weights(counts):
        """Turn per-class counts into ``(frequency, weight)`` arrays.

        The weight is ``median(frequency) / frequency``. Classes that never
        occur get weight 0 instead of the infinity (plus a divide-by-zero
        warning) that a plain division would produce.
        """
        counts = np.asarray(counts, dtype=np.float64)
        total = counts.sum()
        freq = counts / total if total > 0 else np.zeros_like(counts)
        weight = np.zeros_like(freq)
        present = freq > 0
        weight[present] = np.median(freq) / freq[present]
        return freq, weight

if __name__ == "__main__":
    # Manual check of DataLoader19 in both modes. The unused locals
    # `vol_num` and `name` from the original were dropped.
    data_dir = 'MICCAI_BraTS_2018_Data_Training/'#'../data_sample/'
    conf = 'MICCAI_BraTS_2018_Data_Training/config/valid18.config'

    # Training mode: items are single 2D slices.
    brats19 = DataLoader19(data_dir=data_dir, conf=conf, train=True)
    image2d, label2d, im_name = brats19[5]

    print('image size ......')
    print(image2d.shape)

    print('label size ......')
    print(label2d.shape)
    print(im_name)

    # Non-training mode: items are whole subject volumes.
    test = DataLoader19(data_dir=data_dir, conf=conf, train=False)
    image_volume, label_volume, subject = test[0]
    print(image_volume.shape)
    print(label_volume.shape)
    print(subject)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章