語義分割之《CCNet: Criss-Cross Attention for Semantic Segmentation》論文閱讀筆記

  論文地址:CCNet: Criss-Cross Attention for Semantic Segmentation

  代碼地址:CCNet github

一、簡介

  CCNet是2018年11月發佈的一篇語義分割方面的文章中提到的網絡,該網絡有三個優勢:

  • GPU內存友好;
  • 計算高效;
  • 性能好。

  CCNet之前的論文比如FCNs只能管制局部特徵和少部分的上下文信息,空洞卷積只能夠集中於當前像素而無法生成密集的上下文信息,雖然PSANet能夠生成密集的像素級的上下文信息但是計算效率過低,其計算複雜度高達O((H*W)*(H*\W))。因此可以明顯的看出,CCNet的目的是高效的生成密集的像素級的上下文信息。
  Cirss-Cross Attention Block的參數對比如下圖所示:
在這裏插入圖片描述
  CCNet論文的主要貢獻:

  • 提出了Cirss-Cross Attention Module;
  • 提出了高效利用Cirss-Cross Attention Module的CCNet。

二、結構

1、CCNet結構

  CCNet的網絡結構如下圖所示:
在這裏插入圖片描述
  CCNet的基本結構描述如下:

  • 1、圖像通過特徵提取網絡得到feature map的大小爲HWH*W,爲了更高效的獲取密集的特徵圖,將原來的特徵提取網絡中的後面兩個下采樣去除,替換爲空洞卷積,使得feature map的大小爲輸入圖像的1/8;
  • 2、feature map X分爲兩個分支,分別進入3和4;
  • 3、一個分支先將X進行通道縮減壓縮特徵,然後通過兩個CCA(Cirss-Cross Attention)模塊,兩個模塊共享相同的參數,得到特徵HH^{''}
  • 4、另一個分支保持不變爲X;
  • 5、將3和4兩個分支的特徵融合到一起最終經過upsample得到分割圖像。

2、Criss-Cross Attention

 Criss-Cross Attention模塊的結構如下所示,輸入feature爲HRCWHH\in \mathbb{R}^{C*W*H},HH分爲Q,K,VQ,K,V三個分支,都通過1*1的卷積網絡的進行降維得到Q,KRCWH{Q,K}\in \mathbb{R}^{C^{'}*W*H}C<CC^{'}<C)。其中Attention Map AR(H+W1)WHA\in \mathbb{R}^{(H+W-1)*W*H}QQKK通過Affinity操作計算的。Affinity操作定義爲:
di,u=QuΩi,uT d_{i,u}=Q_u\Omega_{i,u}^{T}
  其中QuRCQ_u\in\mathbb{R}^{C^{'}}是在特徵圖Q的空間維度上的u位置的值。ΩuR(H+W1)C\Omega_u\in\mathbb{R}^{(H+W-1)C^{'}}KKuu位置處的同列和同行的元素的集合。因此,Ωu,iRC\Omega_{u,i}\in\mathbb{R}^{C^{'}}Ωu\Omega_u中的第ii個元素,其中i=[1,2,...,Ωu]i=[1,2,...,|\Omega_u|]。而di,uDd_{i,u}\in D表示QuQ_uΩi,u\Omega_{i,u}之間的聯繫的權重,DR(H+W1)WHD\in \mathbb{R}^{(H+W-1)*W*H}。最後對DD進行在通道維度上繼續進行softmax操作計算Attention Map AA
  另一個分支VV經過一個1*1卷積層得到VRCWHV \in \mathbb{R}^{C*W*H}的適應性特徵。同樣定義VuRCV_u \in \mathbb{R}^CΦuR(H+W1)C\Phi_u\in \mathbb{R}^{(H+W-1)*C}Φu\Phi_uVV上u點的同行同列的集合,則定義Aggregation操作爲:
HuiΦuAi,uΦi,u+Hu H_u^{'}\sum_{i \in |\Phi_u|}{A_{i,u}\Phi_{i,u}+H_u}
  該操作在保留原有feature的同時使用經過attention處理過的feature來保全feature的語義性質。
在這裏插入圖片描述

3、Recurrent Criss-Cross Attention

  單個Criss-Cross Attention模塊能夠提取更好的上下文信息,但是下圖所示,根據criss-cross attention模塊的計算方式左邊右上角藍色的點只能夠計算到和其同列同行的關聯關係,也就是說相應的語義信息的傳播無法到達左下角的點,因此再添加一個Criss-Cross Attention模塊可以將該語義信息傳遞到之前無法傳遞到的點。
在這裏插入圖片描述
  採用Recurrent Criss-Cross Attention之後,先定義loop=2,第一個loop的attention map爲AA,第二個loop的attention map爲AA^{'},從原feature上位置x,yx^{'},y^{'}到權重Ai,x,yA_{i,x,y}的映射函數爲Ai,x,y=f(A,x,y,x,y)A_{i,x,y}=f(A,x,y,x^{'},y^{'}),feature HH中的位置用θ\theta表示,feature中HH^{''}uu表示,如果uuθ\theta相同則:
Hu[f(A,u,θ)+1]f(A,u,θ)Hθ H_u^{''}\leftarrow[f(A,u,\theta)+1]\cdot f(A^{'},u,\theta)\cdot H_{\theta}
  其中\leftarrow表示加到操作,如果uuθ\theta不同則:
Hu[f(A,ux,θy,θx,θy)f(A,ux,uy,ux,θy)+f(A,θx,uy,θx,θy)f(A,ux,uy,θx,θy)]Hθ H_u^{''}\leftarrow[f(A,u_x,\theta_{y}, \theta_{x}, \theta_{y})\cdot f(A^{'},u_x,u_{y}, u_{x}, \theta_{y})+f(A,\theta_x,u_{y}, \theta_{x}, \theta_{y})\cdot f(A^{'},u_x,u_{y}, \theta_{x}, \theta_{y})]\cdot H_{\theta}
  Cirss-Cross Attention模塊可以應用於多種任務不僅僅是語義分割,作者同樣在多種任務中使用了該模塊,可以參考論文。

4、代碼

  下面是Cirss-Cross Attention模塊的代碼可以看到ca_weight便是Affinity操作,ca_map便是Aggregation操作。

class CrissCrossAttention(nn.Module):
    """ Criss-Cross Attention Module"""
    def __init__(self,in_dim):
        super(CrissCrossAttention,self).__init__()
        self.chanel_in = in_dim

        self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
        self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
        self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self,x):
        proj_query = self.query_conv(x)
        proj_key = self.key_conv(x)
        proj_value = self.value_conv(x)

        energy = ca_weight(proj_query, proj_key)
        attention = F.softmax(energy, 1)
        out = ca_map(attention, proj_value)
        out = self.gamma*out + x

        return out

  Affinity操作定義如下:

class CA_Weight(autograd.Function):
    @staticmethod
    def forward(ctx, t, f):
        # Save context
        n, c, h, w = t.size()
        size = (n, h+w-1, h, w)
        weight = torch.zeros(size, dtype=t.dtype, layout=t.layout, device=t.device)

        _ext.ca_forward_cuda(t, f, weight)
        
        # Output
        ctx.save_for_backward(t, f)

        return weight

    @staticmethod
    @once_differentiable
    def backward(ctx, dw):
        t, f = ctx.saved_tensors

        dt = torch.zeros_like(t)
        df = torch.zeros_like(f)

        _ext.ca_backward_cuda(dw.contiguous(), t, f, dt, df)

        _check_contiguous(dt, df)

        return dt, df

  Aggregation操作定義如下:

class CA_Map(autograd.Function):
    @staticmethod
    def forward(ctx, weight, g):
        # Save context
        out = torch.zeros_like(g)
        _ext.ca_map_forward_cuda(weight, g, out)
        
        # Output
        ctx.save_for_backward(weight, g)

        return out

    @staticmethod
    @once_differentiable
    def backward(ctx, dout):
        weight, g = ctx.saved_tensors

        dw = torch.zeros_like(weight)
        dg = torch.zeros_like(g)

        _ext.ca_map_backward_cuda(dout.contiguous(), weight, g, dw, dg)

        _check_contiguous(dw, dg)

        return dw, dg

  其中使用ext是c庫文件:
在這裏插入圖片描述
  RCC模塊的實現如下所示:

class RCCAModule(nn.Module):
    def __init__(self, in_channels, out_channels, num_classes):
        super(RCCAModule, self).__init__()
        inter_channels = in_channels // 4
        self.conva = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
                                   InPlaceABNSync(inter_channels))
        self.cca = CrissCrossAttention(inter_channels)
        self.convb = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
                                   InPlaceABNSync(inter_channels))

        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channels+inter_channels, out_channels, kernel_size=3, padding=1, dilation=1, bias=False),
            InPlaceABNSync(out_channels),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0, bias=True)
            )

    def forward(self, x, recurrence=1):
        output = self.conva(x)
        for i in range(recurrence):
            output = self.cca(output)
        output = self.convb(output)

        output = self.bottleneck(torch.cat([x, output], 1))
        return output

  CCNet的整體結構:

class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes):
        self.inplanes = 128
        super(ResNet, self).__init__()
        self.conv1 = conv3x3(3, 64, stride=2)
        self.bn1 = BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=False)
        self.conv2 = conv3x3(64, 64)
        self.bn2 = BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=False)
        self.conv3 = conv3x3(64, 128)
        self.bn3 = BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=False)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.relu = nn.ReLU(inplace=False)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True) # change
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, multi_grid=(1,1,1))
        #self.layer5 = PSPModule(2048, 512)
        self.head = RCCAModule(2048, 512, num_classes)

        self.dsn = nn.Sequential(
            nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1),
            InPlaceABNSync(512),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0, bias=True)
            )

    def forward(self, x, recurrence=1):
    x = self.relu1(self.bn1(self.conv1(x)))
    x = self.relu2(self.bn2(self.conv2(x)))
    x = self.relu3(self.bn3(self.conv3(x)))
    x = self.maxpool(x)
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x_dsn = self.dsn(x)
    x = self.layer4(x)
    x = self.head(x, recurrence)
    return [x, x_dsn]

三、結果

  與主流的方法的比較:
在這裏插入圖片描述
  下面是不同loop時的效果可以看到loop=2時的效果要比loop=2好。下面是不同loop的attention map。
在這裏插入圖片描述
在這裏插入圖片描述
在這裏插入圖片描述
在這裏插入圖片描述
在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章