Unet图像分割网络Pytorch实现

介绍

之前计划写一篇tensorflow实现的,CSDN上已经有大佬写了。最近一个月接触了一下Pytorch,个人认为Pytorch相较于Tensorflow来说好用很多。本文的内容是我对Unet论文的总结与提炼,需要提醒的是,Unet原文发布的时候还没有提出BN(Batch Normalization). 所以在本文中我会增加这一个步骤。

如果想要安装Python和Pytorch或者获得进一步的信息可以点击Python ,Pytorch

在图像分割这个大问题上,主要有两个流派:U-shape和dialated Conv。本文介绍的是U-shape网络中最为经典的U-Net。随着骨干网路的进化,很多相应衍生出来的网络大多都是对于Unet进行了改进但是本质上的思路还是没有太多的变化。比如结合DenseNet 和Unet的FCDenseNet, Unet++


Unet

Unet是一个为医学图像分割设计的auto-encoder-decoder结构的网络。行业里也把它视作一种FCN(fully connected network)。 它可以分成两个部分,down(encoder) 和 up(decoder)。down的主要结构可以看成conv后面跟maxpool。 up的主要结构是一个upsample后面跟conv。

Unet的核心思想

想要弄清这个问题首先要感性的理解一下卷积的作用。就拿MINIST数据集训练数字识别这个简单的CNN网络为例, 它把一个28*28的图片抽象成一个0-9的向量。卷积可以看成是特征的提取,它可以提取出输入的信息的抽象概念。但是Pool和Conv会损失空间信息。其中,空间信息在pool的过程中损失的更为严重。对于图像分割来说, 空间信息和抽象信息同样重要。既然每一个次pool的时候会严重损失空间信息,也就是说maxpool之间的空间信息多于之后的。于是Unet提出,把down的特征连接到对应的up上。

Unet的结构

Unet
其中灰色箭头copy and crop中的copy就是concatenatecrop是为了让两者的长宽一致
左半边就是down path右半边 就是up path。我们来分别介绍这两个部分。

Down Path

图中input image tile就是我们输入的训练数据。除了第一层是两个conv,其他层都可以看成是maxpool后面跟两个conv。在Unet中绝大部分的conv都是两个conv连用的形式存在的,为了方便,我们可以先自定义一个double_conv类。

# 实现double conv
class double_conv(nn.Module):
    ''' Conv => Batch_Norm => ReLU => Conv2d => Batch_Norm => ReLU
    '''
    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
        self.conv.apply(self.init_weights)
    
    def forward(self, x):
        x = self.conv(x)
        return x

    @staticmethod
    def init_weights(m):
        if type(m) == nn.Conv2d:
            init.xavier_normal(m.weight)
            init.constant(m.bias,0)

下面我们来实现input conv, 它实际上用一个double_conv也就完成了。

# 实现input conv
class inconv(nn.Module):
    ''' input conv layer
        let input 3 channels image to 64 channels
        The oly difference between `inconv` and `down` is maxpool layer 
    '''
    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        x = self.conv(x)
        return x

接下来我们来实现down类,它的结构是一个maxpool接一个double_conv

class down(nn.Module):
    ''' normal down path 
        MaxPool2d => double_conv
    '''
    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool2d(2),
            double_conv(in_ch, out_ch)
        )

    def forward(self, x):
        x = self.mpconv(x)
        return x

Up path

Unet的up path主要的结构是upsampl加上double_conv但是也可以使用ConvTranspose2d代替upsample。下面的代码给出了两种选择。
在up path 中,我们需要将down path 中的特征合并进来。在up.forward中crop从而让两个特征一致。

class up(nn.Module):
    ''' up path
        conv_transpose => double_conv
    '''
    def __init__(self, in_ch, out_ch, Transpose=False):
        super(up, self).__init__()

        #  would be a nice idea if the upsampling could be learned too,
        #  but my machine do not have enough memory to handle all those weights
        if Transpose:
            self.up = nn.ConvTranspose2d(in_ch, in_ch//2, 2, stride=2)
        else:
            # self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                                    nn.Conv2d(in_ch, in_ch//2, kernel_size=1, padding=0),
                                    nn.ReLU(inplace=True))
        self.conv = double_conv(in_ch, out_ch)
        self.up.apply(self.init_weights)

    def forward(self, x1, x2):
        ''' 
            conv output shape = (input_shape - Filter_shape + 2 * padding)/stride + 1
        '''

        x1 = self.up(x1)

        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = nn.functional.pad(x1, (diffX // 2, diffX - diffX//2,
                                    diffY // 2, diffY - diffY//2))

        x = torch.cat([x2,x1], dim=1)
        x = self.conv(x)
        return x

    @staticmethod
    def init_weights(m):
        if type(m) == nn.Conv2d:
            init.xavier_normal(m.weight)
            init.constant(m.bias,0)

轮子已经造好了,那么我们来实现Unet让它跑起来

class Unet(nn.Module):
    def __init__(self, in_ch, out_ch, gpu_ids=[]):
        super(Unet, self).__init__()
        self.loss_stack = 0
        self.matrix_iou_stack = 0
        self.stack_count = 0
        self.display_names = ['loss_stack', 'matrix_iou_stack']
        self.gpu_ids = gpu_ids
        self.bce_loss = nn.BCELoss()
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if torch.cuda.is_available() else torch.device('cpu')
        self.inc = inconv(in_ch, 64)
        self.down1 = down(64, 128)
        # print(list(self.down1.parameters()))
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.drop3 = nn.Dropout2d(0.5)
        self.down4 = down(512, 1024)
        self.drop4 = nn.Dropout2d(0.5)
        self.up1 = up(1024, 512, False)
        self.up2 = up(512, 256, False)
        self.up3 = up(256, 128, False)
        self.up4 = up(128, 64, False)
        self.outc = outconv(64, 1)
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        # self.optimizer = torch.optim.SGD(self.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0005)

    def forward(self):
        x1 = self.inc(self.x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x4 = self.drop3(x4)
        x5 = self.down4(x4)
        x5 = self.drop4(x5)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        self.pred_y = nn.functional.sigmoid(x)

    def set_input(self, x, y):
        self.x = x.to(self.device)
        self.y = y.to(self.device)

    def optimize_params(self):
        self.forward()
        self._bce_iou_loss()
        _ = self.accu_iou()
        self.stack_count += 1
        self.zero_grad()
        self.loss.backward()
        self.optimizer.step()

    def accu_iou(self):
        # B is the mask pred, A is the malanoma 
        y_pred = (self.pred_y > 0.5) * 1.0
        y_true = (self.y > 0.5) * 1.0
        pred_flat = y_pred.view(y_pred.numel())
        true_flat = y_true.view(y_true.numel())

        intersection = float(torch.sum(pred_flat * true_flat)) + 1e-7
        denominator = float(torch.sum(pred_flat + true_flat)) - intersection + 2e-7

        self.matrix_iou = intersection/denominator
        self.matrix_iou_stack += self.matrix_iou
        return self.matrix_iou

    def _bce_iou_loss(self):
        y_pred = self.pred_y
        y_true = self.y
        pred_flat = y_pred.view(y_pred.numel())
        true_flat = y_true.view(y_true.numel())

        intersection = torch.sum(pred_flat * true_flat) + 1e-7
        denominator = torch.sum(pred_flat + true_flat) - intersection + 1e-7
        iou = torch.div(intersection, denominator)
        bce_loss = self.bce_loss(pred_flat, true_flat)
        self.loss = bce_loss - iou + 1
        self.loss_stack += self.loss
        
    def get_current_losses(self):
        errors_ret = {}
        for name in self.display_names:
            if isinstance(name, str):
                errors_ret[name] = float(getattr(self, name)) / self.stack_count
        self.loss_stack = 0
        self.matrix_iou_stack = 0
        self.stack_count = 0
        return errors_ret
        
    def eval_iou(self):
        with torch.no_grad():
            self.forward()
            self._bce_iou_loss()
            _ = self.accu_iou()
            self.stack_count += 1

其他的代码就是很固定的pytorch模板代码了。

代码参考自GitHub


转载请标明出处


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章