U-Net Image Segmentation Network: A PyTorch Implementation

Introduction

I originally planned to write a TensorFlow implementation, but there are already good write-ups on CSDN. Over the past month I have been working with PyTorch, and I personally find it much more pleasant to use than TensorFlow. This article is my summary and distillation of the U-Net paper. One thing to note: Batch Normalization (BN) had not yet been proposed when the original U-Net paper was published, so I add that step in this implementation.

If you need to install Python and PyTorch, or want more information, see Python and Pytorch.

For image segmentation there are two main families of approaches: U-shaped networks and dilated convolutions. This article covers the most classic of the U-shaped networks, U-Net. As backbone networks have evolved, many derived networks have appeared, but most of them are refinements of U-Net and the essential idea has not changed much; examples include FCDenseNet (which combines DenseNet with U-Net) and UNet++.


Unet

U-Net is an encoder-decoder style network designed for medical image segmentation. It is also commonly regarded as a type of FCN (fully convolutional network). It can be split into two parts: the down path (encoder) and the up path (decoder). The main building block of the down path is a conv followed by a maxpool; the main building block of the up path is an upsample followed by a conv.

The core idea of U-Net

To understand this, it helps to first get an intuitive feel for what convolution does. Take the simple CNN trained for digit recognition on the MNIST dataset: it abstracts a 28×28 image into a vector over the digits 0-9. Convolution can be viewed as feature extraction; it distills the input into increasingly abstract representations. However, both conv and pooling lose spatial information, and pooling loses it much more severely. For image segmentation, spatial information is just as important as abstract information. Since every pooling step discards a lot of spatial detail, the feature maps before a maxpool carry more spatial information than the ones after it. U-Net therefore proposes concatenating the down-path features onto the corresponding levels of the up path.
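As a concrete illustration (the layer sizes below are just an assumed toy example, not taken from the paper), a padded conv keeps the spatial resolution while a maxpool halves it:

import torch
import torch.nn as nn

x = torch.randn(1, 1, 28, 28)            # dummy MNIST-sized input
conv = nn.Conv2d(1, 16, 3, padding=1)    # 3x3 conv with padding keeps 28x28
pool = nn.MaxPool2d(2)                   # 2x2 maxpool halves it to 14x14

print(conv(x).shape)        # torch.Size([1, 16, 28, 28])
print(pool(conv(x)).shape)  # torch.Size([1, 16, 14, 14])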

The structure of U-Net

(U-Net architecture diagram from the original paper)
In the gray "copy and crop" arrows, the copy is a concatenate, and the crop is there to make the two feature maps match in height and width.
The left half is the down path and the right half is the up path. Let us go through these two parts in turn.

Down Path

In the figure, the "input image tile" is our training input. Except for the first level, which is just two convs, every level can be seen as a maxpool followed by two convs. In U-Net almost every conv appears as a pair of back-to-back convs, so for convenience we first define a double_conv class.

# imports used by all the code in this post
import torch
import torch.nn as nn
from torch.nn import init

# implement double conv: (Conv => BN => ReLU) twice
class double_conv(nn.Module):
    ''' Conv => Batch_Norm => ReLU => Conv2d => Batch_Norm => ReLU
    '''
    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
        self.conv.apply(self.init_weights)
    
    def forward(self, x):
        x = self.conv(x)
        return x

    @staticmethod
    def init_weights(m):
        if type(m) == nn.Conv2d:
            init.xavier_normal_(m.weight)
            init.constant_(m.bias, 0)
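As a quick sanity check (the input size here is an assumed toy example), note that because padding=1 is used, double_conv keeps the spatial size and only changes the number of channels, unlike the unpadded convs in the original paper:

# assumed toy sanity check: with padding=1 the spatial size is preserved,
# only the channel count changes
dc = double_conv(3, 64)
x = torch.randn(1, 3, 256, 256)
print(dc(x).shape)   # torch.Size([1, 64, 256, 256])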

Now let's implement the input conv; it is really just a single double_conv.

# implement input conv
class inconv(nn.Module):
    ''' input conv layer
        maps the 3-channel input image to 64 channels
        The only difference between `inconv` and `down` is the maxpool layer
    '''
    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        x = self.conv(x)
        return x

Next we implement the down class; its structure is a maxpool followed by a double_conv.

class down(nn.Module):
    ''' normal down path 
        MaxPool2d => double_conv
    '''
    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool2d(2),
            double_conv(in_ch, out_ch)
        )

    def forward(self, x):
        x = self.mpconv(x)
        return x
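Again with an assumed toy input, each down block halves the spatial resolution and, in this network, doubles the channel count:

# assumed toy sanity check: maxpool halves H and W, double_conv maps 64 -> 128 channels
d = down(64, 128)
x = torch.randn(1, 64, 256, 256)
print(d(x).shape)   # torch.Size([1, 128, 128, 128])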

Up Path

The main building block of the up path is an upsample followed by a double_conv, though a ConvTranspose2d can be used instead of the upsample; the code below offers both options.
In the up path we also need to merge in the features from the down path. In up.forward the upsampled feature map is padded so that the two feature maps match in size before they are concatenated (the original paper crops the down-path features instead).

class up(nn.Module):
    ''' up path
        conv_transpose => double_conv
    '''
    def __init__(self, in_ch, out_ch, Transpose=False):
        super(up, self).__init__()

        #  it would be a nice idea if the upsampling could be learned too (ConvTranspose2d),
        #  but my machine does not have enough memory to handle all those weights
        if Transpose:
            self.up = nn.ConvTranspose2d(in_ch, in_ch//2, 2, stride=2)
        else:
            # self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                                    nn.Conv2d(in_ch, in_ch//2, kernel_size=1, padding=0),
                                    nn.ReLU(inplace=True))
        self.conv = double_conv(in_ch, out_ch)
        self.up.apply(self.init_weights)

    def forward(self, x1, x2):
        ''' 
            conv output shape = (input_shape - Filter_shape + 2 * padding)/stride + 1
        '''

        x1 = self.up(x1)

        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = nn.functional.pad(x1, (diffX // 2, diffX - diffX//2,
                                    diffY // 2, diffY - diffY//2))

        x = torch.cat([x2,x1], dim=1)
        x = self.conv(x)
        return x

    @staticmethod
    def init_weights(m):
        if type(m) == nn.Conv2d:
            init.xavier_normal_(m.weight)
            init.constant_(m.bias, 0)
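To see how the padding lines the two feature maps up, here is an assumed shape check (the odd spatial size of x2 is deliberate, to exercise the asymmetric padding):

# assumed toy sanity check for the skip connection
u = up(1024, 512)
x1 = torch.randn(1, 1024, 28, 28)   # feature map coming from the level below
x2 = torch.randn(1, 512, 57, 57)    # matching down-path feature map
print(u(x1, x2).shape)              # torch.Size([1, 512, 57, 57])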

The building blocks are ready; now let's put them together into the full U-Net and get it running.
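Note that the Unet class below uses an outconv class that is not shown in this post. A minimal sketch of what it is assumed to look like (a single 1×1 conv mapping the final 64 channels to the output map) would be:

# assumed definition of outconv: a single 1x1 conv producing the output map
class outconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        return self.conv(x)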

class Unet(nn.Module):
    def __init__(self, in_ch, out_ch, gpu_ids=[]):
        super(Unet, self).__init__()
        self.loss_stack = 0
        self.matrix_iou_stack = 0
        self.stack_count = 0
        self.display_names = ['loss_stack', 'matrix_iou_stack']
        self.gpu_ids = gpu_ids
        self.bce_loss = nn.BCELoss()
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if (self.gpu_ids and torch.cuda.is_available()) else torch.device('cpu')
        self.inc = inconv(in_ch, 64)
        self.down1 = down(64, 128)
        # print(list(self.down1.parameters()))
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.drop3 = nn.Dropout2d(0.5)
        self.down4 = down(512, 1024)
        self.drop4 = nn.Dropout2d(0.5)
        self.up1 = up(1024, 512, False)
        self.up2 = up(512, 256, False)
        self.up3 = up(256, 128, False)
        self.up4 = up(128, 64, False)
        self.outc = outconv(64, 1)
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        # self.optimizer = torch.optim.SGD(self.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0005)

    def forward(self):
        x1 = self.inc(self.x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x4 = self.drop3(x4)
        x5 = self.down4(x4)
        x5 = self.drop4(x5)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        self.pred_y = torch.sigmoid(x)

    def set_input(self, x, y):
        self.x = x.to(self.device)
        self.y = y.to(self.device)

    def optimize_params(self):
        self.forward()
        self._bce_iou_loss()
        _ = self.accu_iou()
        self.stack_count += 1
        self.zero_grad()
        self.loss.backward()
        self.optimizer.step()

    def accu_iou(self):
        # binarize the predicted mask and the melanoma ground-truth mask at 0.5
        y_pred = (self.pred_y > 0.5) * 1.0
        y_true = (self.y > 0.5) * 1.0
        pred_flat = y_pred.view(y_pred.numel())
        true_flat = y_true.view(y_true.numel())

        intersection = float(torch.sum(pred_flat * true_flat)) + 1e-7
        denominator = float(torch.sum(pred_flat + true_flat)) - intersection + 2e-7

        self.matrix_iou = intersection/denominator
        self.matrix_iou_stack += self.matrix_iou
        return self.matrix_iou

    def _bce_iou_loss(self):
        y_pred = self.pred_y
        y_true = self.y
        pred_flat = y_pred.view(y_pred.numel())
        true_flat = y_true.view(y_true.numel())

        intersection = torch.sum(pred_flat * true_flat) + 1e-7
        denominator = torch.sum(pred_flat + true_flat) - intersection + 1e-7
        iou = torch.div(intersection, denominator)
        bce_loss = self.bce_loss(pred_flat, true_flat)
        self.loss = bce_loss - iou + 1
        self.loss_stack += self.loss.item()  # accumulate as a Python float so old graphs can be freed
        
    def get_current_losses(self):
        errors_ret = {}
        for name in self.display_names:
            if isinstance(name, str):
                errors_ret[name] = float(getattr(self, name)) / self.stack_count
        self.loss_stack = 0
        self.matrix_iou_stack = 0
        self.stack_count = 0
        return errors_ret
        
    def eval_iou(self):
        with torch.no_grad():
            self.forward()
            self._bce_iou_loss()
            _ = self.accu_iou()
            self.stack_count += 1

The rest is fairly standard PyTorch boilerplate.
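For completeness, here is a minimal sketch of what that boilerplate might look like; the train_loader and the number of epochs are assumptions for illustration, not part of the original code.

# assumed minimal training loop around the Unet class above
model = Unet(in_ch=3, out_ch=1, gpu_ids=[0])
model.to(model.device)

for epoch in range(50):                       # number of epochs is arbitrary here
    for x, y in train_loader:                 # train_loader is assumed to yield (image, mask) batches
        model.set_input(x, y)
        model.optimize_params()
    print(epoch, model.get_current_losses())  # averaged BCE+IoU loss and IoU since last call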

The code is based on a reference implementation on GitHub.


Please credit the source when reposting.

