構建2D U-net來做BraTS2019

研二了,第一次用深度學習做分割,感覺自己落伍了好多,方向是圖像處理,一直在用傳統法拼拼湊湊,同學都說深度學習要發文章得有好的數學基礎,自知數學基礎差的情況下還是要接觸一下的,畢竟萬事開頭難,不學習就永遠不會,記錄一下自己的學習過程。
(環境

網絡結構

首先根據前人經驗先搭建網絡,下圖是Unet的網絡結構圖:
在這裏插入圖片描述
觀察到conv操作蠻多的,不管是下采樣層還是反捲積層中都用到,那麼先寫一個該操作的class打包一下:

class Conv3x3(nn.Module):
    def __init__(self, inputCh, outputCh):
        super(Conv3x3, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(inputCh, outputCh, kernel_size=3, stride=1, padding=1),#卷積核3x3,in->out
            nn.BatchNorm2d(outPutCh),#規範化
            nn.ReLU(inplace=True),#激活函數ReLU
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(outputCh, outputCh, kernel_size=3, stride=1, padding=1),#根據圖,上一次的out->out
            nn.BatchNorm2d(outputCh),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):#前向傳播
        x = self.conv1(x)
        x = self.conv2(x)
        return x

打包完卷積的操作之後,再把上採樣的操作整理一下:

class TransConv(nn.Module):
    def __init__(self, inputCh, outputCh):
        super(TransConv, self).__init__()
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(inputCh, outputCh, kernel_size=3, stride=2, padding=1, output_padding=1, dilation=1),
            nn.BatchNorm2d(outputCh),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.conv(x)
        return x

class UpSam(nn.Module):
    def __init__(self, inputCh, outputCh):
        super(UpSam, self).__init__()
        self.upconv = TransConv(inputCh, outputCh)#反捲積
        self.conv = Conv3x3(2 * outputCh, outputCh)#這裏用到上面寫的conv操作

    def forward(self, x, convfeatures):
        x = self.upconv(x)
        x = torch.cat([x, convfeatures], dim=1)
        x = self.conv(x)
        return x

至此完成圖中藍色箭頭,灰色箭頭,綠色箭頭的定義,紅色箭頭是maxpool,實質是下采樣,可以跟其他block組合到一起,整體網絡如下:

class UNet(nn.Module):
    def __init__(self, inputCh=4, outputCh=5, size=64):#4種模態數據,擬輸出5個類別(label數據0~4表示:背景、壞死組織、囊腫、腫瘤核心、整體腫瘤)
        super(UNet, self).__init__()
        channels = []
        for i in range(5):
            channels.append((2 ** i) * size)#對應圖像的size
        self.downLayer1 = Conv3x3(inputCh, channels[0])
        self.downLayer2 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                        Conv3x3(channels[0], channels[1]))

        self.downLayer3 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                        Conv3x3(channels[1], channels[2]))

        self.downLayer4 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                        Conv3x3(channels[2], channels[3]))

        self.bottomLayer = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                        Conv3x3(channels[3], channels[4]))

        self.upLayer1 = UpSam(channels[4], channels[3]) 
        self.upLayer2 = UpSam(channels[3], channels[2])
        self.upLayer3 = UpSam(channels[2], channels[1])
        self.upLayer4 = UpSam(channels[1], channels[0])

        self.outLayer = nn.Conv2d(channels[0], outputCh, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
    #前半條路
        x1 = self.downLayer1(x)     # size(32)   * 16    * W    * H
        x2 = self.downLayer2(x1)    # size(64)   * 16/2  * W/2  * H/2
        x3 = self.downLayer3(x2)    # size(128)  * 16/4  * W/4  * H/4
        x4 = self.downLayer4(x3)    # size(256)  * 16/8  * W/8  * H/8
	#最底層
        x5 = self.bottomLayer(x4)   # size(512)  * 16/16 * W/16 * H/16
	#後半條路
        x = self.upLayer1(x5, x4)   # size(256)  * 16/8 * W/8 * H/8
        x = self.upLayer2(x, x3)    # size(128)  * 16/4 * W/4 * H/4
        x = self.upLayer3(x, x2)    # size(64)   * 16/2 * W/2 * H/2
        x = self.upLayer4(x, x1)    # size(32)   * 16   * W   * H
        x = self.outLayer(x)        # outputCh(2 )   * 16   * W   * H
        return x

網絡構建完畢,寫個main函數驗證下看看:

if __name__ == "__main__":
    net = UNet(4, 5, degree=64)
    batch_size = 4
    a = torch.randn(batch_size, 4, 192, 192)#隨便搞點數據扔進去
    b = net(a)
    print(a.shape)
    print(b.shape)

在這裏插入圖片描述
可以從調試結果看到,網絡輸出的結果與網絡的輸入是同維度的(4,192,192),輸入包含4個模態,輸出包含了5個類別,這與我們期望的結果吻合,應該沒啥事問題,那麼就準備寫DataLoader了;

代碼

以下這部分是2D-UNet網絡的代碼,如果覺得卷積操作和上採樣操作比較佔篇幅,寫個import包括進去就好了;

import sys
import math
import torch
import torch.nn as nn

class Conv3x3(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(Conv3x3, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class TransConv(nn.Module):
    def __init__(self, inputCh, outputCh):
        super(TransConv, self).__init__()
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(inputCh, outputCh, kernel_size=3, stride=2, padding=1, output_padding=1, dilation=1),
            nn.BatchNorm2d(outputCh),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.conv(x)
        return x

class UpSam(nn.Module):
    def __init__(self, inputCh, outputCh):
        super(UpSam, self).__init__()
        self.upconv = TransConv(inputCh, outputCh)#反捲積
        self.conv = Conv3x3(2 * outputCh, outputCh)#這裏用到上面寫的conv操作

    def forward(self, x, convfeatures):
        x = self.upconv(x)
        x = torch.cat([x, convfeatures], dim=1)
        x = self.conv(x)
        return x

class UNet2D(nn.Module):
    def __init__(self, in_ch=4, out_ch=2, degree=64):
        super(UNet2D, self).__init__()

        chs = []
        for i in range(5):
            chs.append((2 ** i) * degree)

        self.downLayer1 = Conv3x3(in_ch, chs[0])
        self.downLayer2 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                        Conv3x3(chs[0], chs[1]))

        self.downLayer3 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                        Conv3x3(chs[1], chs[2]))

        self.downLayer4 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                        Conv3x3(chs[2], chs[3]))

        self.bottomLayer = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                        Conv3x3(chs[3], chs[4]))

        self.upLayer1 = Upsam(chs[4], chs[3])
        self.upLayer2 = Upsam(chs[3], chs[2])
        self.upLayer3 = Upsam(chs[2], chs[1])
        self.upLayer4 = Upsam(chs[1], chs[0])

        self.outLayer = nn.Conv2d(chs[0], out_ch, kernel_size=3, stride=1, padding=1)


    def forward(self, x):
        x1 = self.downLayer1(x)     
        x2 = self.downLayer2(x1)    
        x3 = self.downLayer3(x2)    
        x4 = self.downLayer4(x3)    
        x5 = self.bottomLayer(x4)   

        x = self.upLayer1(x5, x4)   
        x = self.upLayer2(x, x3)    
        x = self.upLayer3(x, x2)    
        x = self.upLayer4(x, x1)    
        x = self.outLayer(x)        
        return x


if __name__ == "__main__":
    net = UNet2D(4, 5, degree=64)
    batch_size = 4
    a = torch.randn(batch_size, 4, 192, 192)
    b = net(a)
    print(b.shape)


數據讀入

寫好網絡以後,就該寫讀入方式,這裏參考了並修改了別人的讀寫方式,有點點繁瑣,首先需要獲取HGG/LGG的全部文件夾名字,具體如何獲取可以搜索“獲取文件夾下的子文件夾名字”,然後pip install SimpleITK,如果下載得慢可以換清華鏡像源,用SimpleITK讀入nii文件也比較容易。

讀入NII

def load_nii_as_array(img_name):
    img = sitk.ReadImage(img_name)#img_name是文件路徑
    nda = sitk.GetArrayFromImage(img) #返回[155,240,240]的ndarray類型
    return nda

簡單的數據歸一化

def norm_vol(data):
    data = data.astype(np.float)
    index = data.nonzero()#創建與data一樣的掩模圖,作爲索引
    smax = np.max(data[index])#在Data裏找最大
    smin = np.min(data[index])
    if smax - smin == 0:#如果圖像是背景全0的情況,不作歸一化
        return data
    else:
        data[index] = (data[index] - smin * 1.0) / (smax - smin)
        return data

讀入的類

需要注意的是,我把所有文件名寫成了conf文件,內容是這樣的:

HGG/BraTS19_2013_10_1
HGG/BraTS19_2013_11_1
HGG/BraTS19_2013_12_1
HGG/BraTS19_2013_13_1
HGG/BraTS19_2013_14_1
HGG/BraTS19_2013_17_1
HGG/BraTS19_2013_18_1
HGG/BraTS19_2013_19_1
HGG/BraTS19_2013_20_1
HGG/BraTS19_2013_21_1
HGG/BraTS19_2013_22_1
HGG/BraTS19_2013_23_1
HGG/BraTS19_2013_25_1
HGG/BraTS19_2013_26_1
HGG/BraTS19_2013_27_1
class DataLoader19(Dataset):
    def __init__(self, data_dir, conf='../config/train19.conf', train=True):
        img_lists = []
        train_config = open(conf).readlines()
        for data in train_config:
            img_lists.append(os.path.join(data_dir, data.strip('\n')))
        
        self.data = []
        self.freq = np.zeros(5)
        self.zero_vol = np.zeros((4, 240, 240))
        count = 0
        for subject in img_lists:
            count += 1
            if count % 10 == 0:
                print('loading imageSets %d' %count)
            volume, label = DataLoader19.get_subject(subject)   # 4 * 155 * 240 * 240,  155 * 240 * 240
            volume = norm_vol(volume)

            self.freq += self.get_freq(label)
            if train is True:
                length = volume.shape[1]
                for i in range(length):
                    name = subject + '=slice' + str(i)
                    if (volume[:, i, :, :] == self.zero_vol).all():  # when training, ignore zero data
                        continue
                    else:
                        self.data.append([volume[:, i, :, :], label[i, :, :], name])
            else:
                volume = np.transpose(volume, (1, 0, 2, 3))
                self.data.append([volume, label, subject])
        self.freq = self.freq / np.sum(self.freq)
        self.weight = np.median(self.freq) / self.freq
        print('********  Finish loading data  ********')
        print('********  Weight for all classes  ********')
        print(self.weight)
        if train is True:
            print('********  Total number of 2D images is ' + str(len(self.data)) + ' **********')
        else:
            print('********  Total number of subject is ' + str(len(self.data)) + ' **********')


def __getitem__(self, index):
       
        [image, label, name] = self.data[index]  #獲取單個數據和標籤,包括文件名
        
        image = torch.from_numpy(image).float()  # Float Tensor 4, 240, 240
        label = torch.from_numpy(label).float()    # Float Tensor 240, 240
        return image, label, name

		def get_subject(subject):
        # **************** get file ****************
        files = os.listdir(subject)  #
        multi_mode_dir = []
        label_dir = ""
        for f in files:
            if 'flair' in f :    # if is data or 't1' in f or 't1ce' in f or 't2' in f
                multi_mode_dir.append(f)
            elif 'seg' in f:        # if is label
                label_dir = f

        # ********** load 4 mode images **********
        multi_mode_imgs = []  # list size :4      item size: 155 * 240 * 240
        for mod_dir in multi_mode_dir:
            path = os.path.join(subject, mod_dir)  # absolute directory
            img = load_nii_as_array(path)#+ '/' + mod_dir + '.nii.gz'
            multi_mode_imgs.append(img)

        # ********** get label **********
        label_dir = os.path.join(subject, label_dir)# 
        label = load_nii_as_array(label_dir)  #
        volume = np.asarray(multi_mode_imgs)
        return volume, label

    def get_freq(self, label):
        class_count = np.zeros((5))
        for i in range(5):
            a = (label == i) + 0
            class_count[i] = np.sum(a)
        return class_count

if __name__ == "__main__":
    vol_num = 4
    data_dir = 'MICCAI_BraTS_2018_Data_Training/'#'../data_sample/'
    conf = 'MICCAI_BraTS_2018_Data_Training/config/valid18.config'
    # test for training data
    brats19 = DataLoader19(data_dir=data_dir, conf=conf, train=True)
    image2d, label2d, im_name = brats19[5]

    print('image size ......')
    print(image2d.shape)             # (4,  240, 240)

    print('label size ......')
    print(label2d.shape)             # (240, 240)
    print(im_name)
    name = im_name.split('/')[-1]
 
    test = DataLoader19(data_dir=data_dir, conf=conf, train=False)
    image_volume, label_volume, subject = test[0]
    print(image_volume.shape)
    print(label_volume.shape)
    print(subject)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章