Understanding and Implementing Attention U-Net

Introduction

"Attention U-Net: Learning Where to Look for the Pancreas" was published at MIDL 2018. In contrast to precise semantic segmentation built as a localize-then-segment cascade, it adds attention gates so that the network implicitly learns to emphasize the features that matter for pancreas segmentation and to suppress the irrelevant ones.

Network Architecture

The network architecture resembles U-Net. The difference is that where U-Net simply concatenates the encoder and decoder features at each level, Attention U-Net inserts an attention gate that re-weights the encoder features of the same level before concatenation.

 

Attention gate

The attention computation proceeds as follows. First, the corresponding decoder output g is upsampled and passed through a convolution (similar in effect to a transposed convolution) so that it matches x spatially. Then 1×1 convolutions project both g and x (the encoder features from the same level) down to a smaller channel count (half of x's channels in the implementation below). The two projections are summed and passed through a ReLU; another 1×1 convolution reduces the result to a single channel, and a sigmoid then yields a one-channel attention map with the same spatial size as x. Finally, this map is multiplied element-wise with x to produce the re-weighted features.
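To make the final step concrete, here is a toy illustration (not from the paper) of how a sigmoid attention map re-weights features element-wise: coefficients near 1 pass features through, coefficients near 0 suppress them.

```python
import torch

# Toy example: gating a 2x2 "feature map" with a sigmoid attention map.
# The logits are hand-picked so the coefficients come out ~1, ~0, and 0.5.
x = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0]])            # encoder features
logits = torch.tensor([[10.0, -10.0],
                       [0.0, 10.0]])      # pre-sigmoid attention scores
alpha = torch.sigmoid(logits)             # attention coefficients in (0, 1)
gated = x * alpha                         # element-wise re-weighting
print(gated)                              # approximately [[1.0, 0.0], [1.5, 4.0]]
```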

Network architecture implementation:

import torch
from torch import nn
from torchvision import models
from backbone import VGGBlock, Attention_block, ResNeXtBLOCK, ResNetBLOCK, DANetHead

__all__ = ['UNet', 'ATT_UNet', 'NestedUNet', 'DANet']

# Registry mapping block names to classes. Note: these references come from
# the `backbone` imports above; the same-named classes defined below shadow
# the imported names but do not update this dict.
blockdict = {
    'VGGBlock': VGGBlock,
    'ResNeXtBLOCK': ResNeXtBLOCK,
    'ResNetBLOCK': ResNetBLOCK,
    'Attention_block': Attention_block
}

class ResNeXtBLOCK(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels, cardinality=16):
        super().__init__()

        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, middle_channels, 3, padding=1, groups=cardinality, bias=False)
        self.bn2 = nn.BatchNorm2d(middle_channels)
        self.conv3 = nn.Conv2d(middle_channels, out_channels, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out1 = self.relu(self.bn1(self.conv1(x)))     # 1x1 reduce
        out2 = self.relu(self.bn2(self.conv2(out1)))  # 3x3 grouped convolution
        out3 = self.bn3(self.conv3(out2))             # 1x1 expand
        # Residual shortcut; note this requires middle_channels == out_channels,
        # which holds for every instantiation in ATT_UNet below.
        out = self.relu(out3 + out1)

        return out


class Attention_block(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()

        self.Up_conv = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(F_g, F_g, 1, bias=False)
        )

        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g0 = self.Up_conv(g)      # upsample gating signal to x's resolution
        g1 = self.W_g(g0)         # project g to F_int channels
        x1 = self.W_x(x)          # project x to F_int channels
        psi = self.relu(g1 + x1)  # additive attention
        psi = self.psi(psi)       # single-channel attention map in (0, 1)

        return x * psi            # re-weight the encoder features

class ATT_UNet(nn.Module):
    def __init__(self, num_classes, block='ResNeXtBLOCK', input_channels=3, att='Attention_block', **kwargs):
        super().__init__()
        backbone = blockdict[block]  # conv block class used at every stage
        att = blockdict[att]         # attention gate class
        nb_filter = [32, 64, 128, 256, 512]

        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv0_0 = backbone(input_channels, nb_filter[0], nb_filter[0]) # 3 32
        self.conv0_1 = backbone(nb_filter[0], nb_filter[1], nb_filter[1])   # 32 64
        self.conv1_2 = backbone(nb_filter[1], nb_filter[2], nb_filter[2])   # 64 128
        self.conv2_3 = backbone(nb_filter[2], nb_filter[3], nb_filter[3])   # 128 256
        self.conv3_4 = backbone(nb_filter[3], nb_filter[4], nb_filter[4])   # 256 512

        self.att4_3 = att(nb_filter[4], nb_filter[3], nb_filter[3]//2)
        self.conv4_3 = backbone(nb_filter[4], nb_filter[3], nb_filter[3])
        self.conv3_3 = backbone(nb_filter[4], nb_filter[3], nb_filter[3])

        self.att3_2 = att(nb_filter[3], nb_filter[2], nb_filter[2]//2)
        self.conv3_2 = backbone(nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv2_2 = backbone(nb_filter[3], nb_filter[2], nb_filter[2])

        self.att2_1 = att(nb_filter[2], nb_filter[1], nb_filter[1]//2)
        self.conv2_1 = backbone(nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv1_1 = backbone(nb_filter[2], nb_filter[1], nb_filter[1])

        self.att1_0 = att(nb_filter[1], nb_filter[0], nb_filter[0]//2)
        self.conv1_0 = backbone(nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv0 = backbone(nb_filter[1], nb_filter[0], nb_filter[0])

        self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)


    def forward(self, x):
        # Encoder path
        x0 = self.conv0_0(x)
        x1 = self.conv0_1(self.pool(x0))
        x2 = self.conv1_2(self.pool(x1))
        x3 = self.conv2_3(self.pool(x2))
        x4 = self.conv3_4(self.pool(x3))

        # Decoder path: at each level, concatenate the attention-gated skip
        # features with the upsampled (and convolved) decoder features.
        x_3 = self.conv3_3(torch.cat([self.att4_3(x4, x3), self.conv4_3(self.up(x4))], 1))
        x_2 = self.conv2_2(torch.cat([self.att3_2(x_3, x2), self.conv3_2(self.up(x_3))], 1))
        x_1 = self.conv1_1(torch.cat([self.att2_1(x_2, x1), self.conv2_1(self.up(x_2))], 1))
        x_0 = self.conv0(torch.cat([self.att1_0(x_1, x0), self.conv1_0(self.up(x_1))], 1))

        return self.final(x_0)
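As a sanity check on the shapes, here is a standalone sketch (bare layers only, so it runs without the `backbone` module) mirroring what the attention gate does at the deepest skip connection, where F_g=512, F_l=256, F_int=128:

```python
import torch
from torch import nn

# Standalone shape walk-through of one attention gate, mirroring
# Attention_block at the deepest skip (F_g=512, F_l=256, F_int=128).
g = torch.randn(1, 512, 16, 16)   # decoder gating signal (half resolution)
x = torch.randn(1, 256, 32, 32)   # encoder skip features

up  = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                    nn.Conv2d(512, 512, 1, bias=False))
W_g = nn.Sequential(nn.Conv2d(512, 128, 1), nn.BatchNorm2d(128))
W_x = nn.Sequential(nn.Conv2d(256, 128, 1), nn.BatchNorm2d(128))
psi = nn.Sequential(nn.Conv2d(128, 1, 1), nn.BatchNorm2d(1), nn.Sigmoid())

a   = psi(torch.relu(W_g(up(g)) + W_x(x)))  # (1, 1, 32, 32) attention map
out = x * a                                 # broadcasts over all 256 channels
print(a.shape, out.shape)
```

The single-channel map broadcasts over the channel dimension, so every channel of x at a given pixel is scaled by the same coefficient.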

 

TRAIN
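The post does not show its training code. For orientation, here is a minimal sketch of a per-pixel segmentation training loop; the 1×1 convolution is only a stand-in for the model (in practice you would build ATT_UNet from the module above), and the data is random.

```python
import torch
from torch import nn

# Minimal training-loop sketch (hedged: not the author's code). A 1x1 conv
# stands in for ATT_UNet so the snippet is self-contained; the loss treats
# segmentation as per-pixel classification.
model = nn.Conv2d(3, 2, kernel_size=1)            # placeholder segmentation model
criterion = nn.CrossEntropyLoss()                 # per-pixel cross-entropy
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

images = torch.randn(4, 3, 64, 64)                # dummy image batch
masks = torch.randint(0, 2, (4, 64, 64))          # dummy integer label masks

for step in range(3):
    optimizer.zero_grad()
    logits = model(images)                        # (4, 2, 64, 64)
    loss = criterion(logits, masks)
    loss.backward()
    optimizer.step()
print(loss.item())
```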

Final training results:

 
