Using U-Net on BraTS2019

I'm in my second year of grad school, and this is my first time doing segmentation with deep learning. I feel like I've fallen way behind: my focus is image processing, yet I've been patching things together with traditional methods all along. My classmates say that publishing deep-learning papers takes a solid math background, and while I know mine is weak, I still want to get my hands dirty. After all, every beginning is hard, and if you never learn you'll never know. So, enough small talk; here is a record of my learning process.
(Environment: Python 3.7 + PyTorch + Spyder)

First, following others' experience, let's build the network. The figure below shows the U-Net architecture:
(Figure: U-Net architecture diagram)
Notice that the conv operation shows up a lot; it is used in both the downsampling layers and the transposed-convolution layers, so let's first wrap it up in a class:

class Conv3x3(nn.Module):
    def __init__(self, inputCh, outputCh):
        super(Conv3x3, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(inputCh, outputCh, kernel_size=3, stride=1, padding=1),  # 3x3 kernel, in -> out
            nn.BatchNorm2d(outputCh),  # batch normalization
            nn.ReLU(inplace=True),  # ReLU activation
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(outputCh, outputCh, kernel_size=3, stride=1, padding=1),  # per the diagram, previous out -> out
            nn.BatchNorm2d(outputCh),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):  # forward pass
        x = self.conv1(x)
        x = self.conv2(x)
        return x

With the convolution wrapped up, let's also package the upsampling operations:

class TransConv(nn.Module):
    def __init__(self, inputCh, outputCh):
        super(TransConv, self).__init__()
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(inputCh, outputCh, kernel_size=3, stride=2, padding=1, output_padding=1, dilation=1),
            nn.BatchNorm2d(outputCh),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.conv(x)
        return x

class UpSam(nn.Module):
    def __init__(self, inputCh, outputCh):
        super(UpSam, self).__init__()
        self.upconv = TransConv(inputCh, outputCh)  # transposed convolution
        self.conv = Conv3x3(2 * outputCh, outputCh)  # reuse the conv block defined above

    def forward(self, x, convfeatures):
        x = self.upconv(x)
        x = torch.cat([x, convfeatures], dim=1)
        x = self.conv(x)
        return x
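A quick aside on shapes before assembling the network: with kernel_size=3, stride=2, padding=1, output_padding=1, PyTorch's ConvTranspose2d produces an output of size (H - 1)*2 - 2*1 + 1*(3 - 1) + 1 + 1 = 2H, i.e. exactly double, which is what lets the upsampled map line up with the encoder feature map for torch.cat. A throwaway check (the channel and size numbers below are just for illustration):

import torch

conv = Conv3x3(64, 128)
up = UpSam(128, 64)

x = torch.randn(1, 64, 48, 48)
feat = conv(x)                     # 3x3 convs with padding=1 keep H and W: (1, 128, 48, 48)
skip = torch.randn(1, 64, 96, 96)  # stand-in for an encoder feature map at double resolution
out = up(feat, skip)               # TransConv doubles 48 -> 96, concat with skip, then Conv3x3: (1, 64, 96, 96)
print(feat.shape, out.shape)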

That takes care of the blue, gray, and green arrows in the diagram. The red arrow is max pooling, which is really just downsampling, so it can be bundled together with the other blocks. The network as a whole:

class UNet(nn.Module):
    def __init__(self, inputCh=4, outputCh=5, size=64):  # 4 input modalities; 5 output classes covering label values 0~4 (in BraTS 2019: 0 background, 1 necrotic/non-enhancing core, 2 edema, 4 enhancing tumor; 3 is unused)
        super(UNet, self).__init__()
        channels = []
        for i in range(5):
            channels.append((2 ** i) * size)  # feature-channel counts per depth: size, 2*size, ..., 16*size
        self.downLayer1 = Conv3x3(inputCh, channels[0])
        self.downLayer2 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                        Conv3x3(channels[0], channels[1]))

        self.downLayer3 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                        Conv3x3(channels[1], channels[2]))

        self.downLayer4 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                        Conv3x3(channels[2], channels[3]))

        self.bottomLayer = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                        Conv3x3(channels[3], channels[4]))

        self.upLayer1 = UpSam(channels[4], channels[3]) 
        self.upLayer2 = UpSam(channels[3], channels[2])
        self.upLayer3 = UpSam(channels[2], channels[1])
        self.upLayer4 = UpSam(channels[1], channels[0])

        self.outLayer = nn.Conv2d(channels[0], outputCh, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # encoder path
        x1 = self.downLayer1(x)     # (B, size,    W,    H)
        x2 = self.downLayer2(x1)    # (B, 2*size,  W/2,  H/2)
        x3 = self.downLayer3(x2)    # (B, 4*size,  W/4,  H/4)
        x4 = self.downLayer4(x3)    # (B, 8*size,  W/8,  H/8)
        # bottleneck
        x5 = self.bottomLayer(x4)   # (B, 16*size, W/16, H/16)
        # decoder path
        x = self.upLayer1(x5, x4)   # (B, 8*size,  W/8,  H/8)
        x = self.upLayer2(x, x3)    # (B, 4*size,  W/4,  H/4)
        x = self.upLayer3(x, x2)    # (B, 2*size,  W/2,  H/2)
        x = self.upLayer4(x, x1)    # (B, size,    W,    H)
        x = self.outLayer(x)        # (B, outputCh, W,   H)
        return x

With the network built, let's write a quick main function to verify it:

if __name__ == "__main__":
    net = UNet(4, 5, size=64)
    batch_size = 4
    a = torch.randn(batch_size, 4, 192, 192)  # throw some random data at it
    b = net(a)
    print(a.shape)
    print(b.shape)

(Figure: console output showing torch.Size([4, 4, 192, 192]) and torch.Size([4, 5, 192, 192]))
The debug output shows that the network's output matches the input in spatial size (192×192): the input carries 4 channels (one per modality) and the output 5 (one per class), which is exactly what we wanted. One thing to keep in mind is that the input height and width must be divisible by 16, since the encoder max-pools four times. Everything looks fine, so the next step is the DataLoader.
Personally, I think the most important part of writing a DataLoader is keeping straight what data goes in and what tensor comes out. I'm still working on it, since reading the files is a bit of a pain; I'll update another day.
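While that's in progress, here is a rough sketch of the direction I'm heading: one folder per case, nibabel to read the .nii.gz volumes, and, to keep things simple for now, only the middle axial slice of each case, center-cropped to 192×192 to match the test above. Everything here (the paths, the file-name pattern, the class name) is a placeholder I haven't run against the real data yet:

import os
from glob import glob

import nibabel as nib  # a common choice for reading .nii/.nii.gz volumes
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class BraTSSliceDataset(Dataset):
    # yields a (4, 192, 192) image tensor and a (192, 192) label map per case
    MODALITIES = ("flair", "t1", "t1ce", "t2")

    def __init__(self, root, crop=192):
        # assumes one folder per case: root/BraTS19_xxx/BraTS19_xxx_flair.nii.gz, ..., ..._seg.nii.gz
        self.cases = sorted(glob(os.path.join(root, "*")))
        self.crop = crop

    def __len__(self):
        return len(self.cases)

    def _center_crop(self, a):
        h, w = a.shape[-2:]
        top, left = (h - self.crop) // 2, (w - self.crop) // 2
        return a[..., top:top + self.crop, left:left + self.crop]

    def __getitem__(self, idx):
        case = self.cases[idx]
        name = os.path.basename(case)
        slices = []
        for m in self.MODALITIES:
            vol = nib.load(os.path.join(case, name + "_" + m + ".nii.gz")).get_fdata()
            sl = vol[:, :, vol.shape[2] // 2]          # middle axial slice, for simplicity
            sl = (sl - sl.mean()) / (sl.std() + 1e-8)  # z-score normalization per slice
            slices.append(sl)
        seg = nib.load(os.path.join(case, name + "_seg.nii.gz")).get_fdata()
        lbl = seg[:, :, seg.shape[2] // 2]
        img = self._center_crop(np.stack(slices))      # (4, 192, 192)
        lbl = self._center_crop(lbl)                   # (192, 192)
        return torch.from_numpy(img).float(), torch.from_numpy(lbl).long()


# hypothetical data path; batch size matches the shape test above
loader = DataLoader(BraTSSliceDataset("./MICCAI_BraTS_2019_Data_Training/HGG"), batch_size=4, shuffle=True)

And for reference, the cleaned-up code of the whole network in one piece: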

import sys
sys.path.append("..")

import torch
import torch.nn as nn

class ConvBlock2d(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(ConvBlock2d, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class ConvTrans2d(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(ConvTrans2d, self).__init__()
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1, output_padding=1, dilation=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.conv1(x)
        return x


class UpBlock2d(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(UpBlock2d, self).__init__()
        self.up_conv = ConvTrans2d(in_ch, out_ch)
        self.conv = ConvBlock2d(2 * out_ch, out_ch)

    def forward(self, x, down_features):
        x = self.up_conv(x)
        x = torch.cat([x, down_features], dim=1)
        x = self.conv(x)
        return x

class UNet2D(nn.Module):
    def __init__(self, in_ch=4, out_ch=2, degree=64):
        super(UNet2D, self).__init__()

        chs = []
        for i in range(5):
            chs.append((2 ** i) * degree)

        self.downLayer1 = ConvBlock2d(in_ch, chs[0])
        self.downLayer2 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                        ConvBlock2d(chs[0], chs[1]))

        self.downLayer3 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                        ConvBlock2d(chs[1], chs[2]))

        self.downLayer4 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                        ConvBlock2d(chs[2], chs[3]))

        self.bottomLayer = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                        ConvBlock2d(chs[3], chs[4]))

        self.upLayer1 = UpBlock2d(chs[4], chs[3])
        self.upLayer2 = UpBlock2d(chs[3], chs[2])
        self.upLayer3 = UpBlock2d(chs[2], chs[1])
        self.upLayer4 = UpBlock2d(chs[1], chs[0])

        self.outLayer = nn.Conv2d(chs[0], out_ch, kernel_size=3, stride=1, padding=1)


    def forward(self, x):
        x1 = self.downLayer1(x)     
        x2 = self.downLayer2(x1)    
        x3 = self.downLayer3(x2)    
        x4 = self.downLayer4(x3)    
        x5 = self.bottomLayer(x4)   

        x = self.upLayer1(x5, x4)   
        x = self.upLayer2(x, x3)    
        x = self.upLayer3(x, x2)    
        x = self.upLayer4(x, x1)    
        x = self.outLayer(x)        
        return x


def netSize(net):
    # helper: count the total number of trainable parameters
    return sum(p.numel() for p in net.parameters() if p.requires_grad)


if __name__ == "__main__":
    net = UNet2D(4, 5, degree=64)
    print("total parameter:" + str(netSize(net)))

    batch_size = 4
    a = torch.randn(batch_size, 4, 192, 192)
    b = net(a)
    print(b.shape)
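Once the DataLoader is ready, training should be straightforward: the raw (B, 5, H, W) scores from outLayer go straight into nn.CrossEntropyLoss, which expects unnormalized logits plus an integer label map of shape (B, H, W). A minimal sketch of one training step with random stand-in data (the optimizer and learning rate here are placeholders, not tuned choices):

import torch
import torch.nn as nn

net = UNet2D(4, 5, degree=64)
criterion = nn.CrossEntropyLoss()                        # takes logits (B, C, H, W) and labels (B, H, W)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)  # placeholder hyperparameters

imgs = torch.randn(4, 4, 192, 192)                       # stand-in for a batch from the future DataLoader
lbls = torch.randint(0, 5, (4, 192, 192))                # stand-in label maps with classes 0~4

optimizer.zero_grad()
loss = criterion(net(imgs), lbls)
loss.backward()
optimizer.step()
print(loss.item())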

