Occasional Paper Reading: BAM


Introduction

Unlike CBAM, which connects its two sub-modules sequentially (channel attention first, spatial attention second), BAM arranges the channel and spatial attention branches in parallel, letting the two interact to tell the network "what" and "where" to focus on. Like CBAM, it is a self-contained module that can be embedded into any CNN.

Architecture

(Figure: overall structure of the BAM module; original image: ./imgs/module.png)

Spatial Attention

The spatial attention module produces a spatial attention map that emphasizes or suppresses features at different spatial locations. It uses dilated convolutions to enlarge the receptive field, so that more contextual information is aggregated. It is computed as:
$$M_S(F) = BN\bigl(f_3^{1\times1}(f_2^{3\times3}(f_1^{3\times3}(f_0^{1\times1}(F))))\bigr)\tag{1}$$
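Here $f_0^{1\times1}$ reduces the channel count by the reduction ratio, the two $3\times3$ convolutions are dilated, and $f_3^{1\times1}$ collapses the result to a single-channel map. With the default dilation of 4 used in the code below, each $3\times3$ convolution has an effective kernel size of $4\times(3-1)+1 = 9$, so stacking the two gives every position of the attention map a $17\times17$ receptive field on the input feature map (the $1\times1$ convolutions do not enlarge it).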

PyTorch code

import torch
import torch.nn as nn
import torch.nn.functional as F   # shared imports for all of the snippets below


class SpatialGate(nn.Module):
    """Spatial attention branch."""

    def __init__(self, gate_channel, reduction_ratio=16, dilation_conv_num=2, dilation_val=4):
        super(SpatialGate, self).__init__()
        self.gate_s = nn.Sequential()

        # 1x1 convolution reduces the channel count (and therefore the parameter count)
        self.gate_s.add_module("gate_s_conv_reduce0", nn.Conv2d(gate_channel, gate_channel//reduction_ratio, kernel_size=1))
        self.gate_s.add_module("gate_s_relu_reduce0", nn.ReLU())

        # dilated convolutions enlarge the receptive field
        for i in range(dilation_conv_num):
            self.gate_s.add_module("gate_s_conv_di_%d" % i,
                                   nn.Conv2d(gate_channel//reduction_ratio, gate_channel//reduction_ratio,
                                             kernel_size=3, padding=dilation_val, dilation=dilation_val))
            self.gate_s.add_module("gate_s_bn_di_%d" % i,
                                   nn.BatchNorm2d(gate_channel//reduction_ratio))
            self.gate_s.add_module("gate_s_relu_di_%d" % i, nn.ReLU())

        # final 1x1 convolution collapses the features to a single-channel spatial map
        self.gate_s.add_module("gate_s_conv_final", nn.Conv2d(gate_channel//reduction_ratio, 1, kernel_size=1))

    def forward(self, x):
        # the (N, 1, H, W) map is broadcast back to the input shape
        return self.gate_s(x).expand_as(x)
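
A minimal shape check (a sketch assuming a 256-channel, 128x128 input, which matches the layer table below): the single-channel map produced by gate_s is broadcast back to the input shape by expand_as.

# hypothetical usage of SpatialGate as defined above
gate_s = SpatialGate(gate_channel=256)
x = torch.randn(1, 256, 128, 128)
print(gate_s(x).shape)   # torch.Size([1, 256, 128, 128])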

Model summary

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1          [1, 16, 128, 128]           4,112
              ReLU-2          [1, 16, 128, 128]               0
            Conv2d-3          [1, 16, 128, 128]           2,320
       BatchNorm2d-4          [1, 16, 128, 128]              32
              ReLU-5          [1, 16, 128, 128]               0
            Conv2d-6          [1, 16, 128, 128]           2,320
       BatchNorm2d-7          [1, 16, 128, 128]              32
              ReLU-8          [1, 16, 128, 128]               0
            Conv2d-9           [1, 1, 128, 128]              17
================================================================

Channel Attention

Since different channels carry different feature responses, the channel attention branch models the relationships between channels. It is computed as:
$$M_C(F) = BN\bigl(MLP(AvgPool(F))\bigr)\tag{2}$$

PyTorch code

class Flatten(nn.Module):
    """Flattens the (N, C, 1, 1) pooled features to (N, C)."""

    def forward(self, x):
        return x.view(x.size(0), -1)


class ChannelGate(nn.Module):
    """Channel attention branch."""

    def __init__(self, gate_channel, reduction_ratio=16, num_layers=1):
        super(ChannelGate, self).__init__()
        self.gate_c = nn.Sequential()

        # flatten the pooled features before the MLP
        self.gate_c.add_module("flatten", Flatten())

        gate_channels = [gate_channel]
        gate_channels += [gate_channel//reduction_ratio] * num_layers
        gate_channels += [gate_channel]

        # hidden layers of the MLP
        for i in range(len(gate_channels) - 2):
            self.gate_c.add_module("gate_c_fc_%d" % i, nn.Linear(gate_channels[i], gate_channels[i+1]))
            self.gate_c.add_module("gate_c_bn_%d" % (i+1), nn.BatchNorm1d(gate_channels[i+1]))
            self.gate_c.add_module("gate_c_relu_%d" % (i+1), nn.ReLU())

        self.gate_c.add_module("gate_c_fc_final", nn.Linear(gate_channels[-2], gate_channels[-1]))

    def forward(self, x):
        # global average pooling over the spatial dimensions
        avg_pool = F.avg_pool2d(x, x.size(2), stride=x.size(2))
        # unsqueeze() restores the height/width dimensions so the result broadcasts over (N, C, H, W)
        return self.gate_c(avg_pool).unsqueeze(2).unsqueeze(3).expand_as(x)
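
The same kind of shape check for the channel branch (a sketch under the same assumed input size); note that BatchNorm1d cannot normalize a single sample in training mode, so either use a batch larger than one or switch to eval().

# hypothetical usage of ChannelGate as defined above
gate_c = ChannelGate(gate_channel=256).eval()   # eval() because of BatchNorm1d with batch size 1
x = torch.randn(1, 256, 128, 128)
print(gate_c(x).shape)   # torch.Size([1, 256, 128, 128]) -- one weight per channel, broadcast over H and W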

Model summary

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
           Flatten-1                   [1, 256]               0
            Linear-2                    [1, 16]           4,112
       BatchNorm1d-3                    [1, 16]              32
              ReLU-4                    [1, 16]               0
            Linear-5                   [1, 256]           4,352
================================================================

Fusion

class BAM(nn.Module):
    def __init__(self, gate_channel):
        super(BAM, self).__init__()
        self.channel_att = ChannelGate(gate_channel)
        self.spatial_att = SpatialGate(gate_channel)

    def forward(self, x):
        # element-wise product of the two branches, squashed to (0, 1) by a sigmoid;
        # the residual "1 +" keeps the original features and lets attention refine them
        att = 1 + torch.sigmoid(self.channel_att(x) * self.spatial_att(x))
        return att * x
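
Since both branches preserve the input shape, the module can be dropped between stages of an existing network. A minimal sketch (the toy backbone and its layer sizes below are illustrative, not from the paper):

# hypothetical toy backbone with BAM inserted between stages
backbone = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    BAM(64),                                          # refine features after the first stage
    nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    BAM(128),                                         # refine features after the second stage
)

backbone.eval()                                       # BatchNorm1d inside ChannelGate needs eval mode (or batch > 1)
print(backbone(torch.randn(1, 3, 64, 64)).shape)      # torch.Size([1, 128, 32, 32])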

Model summary

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
           Flatten-1                   [1, 256]               0
            Linear-2                    [1, 16]           4,112
       BatchNorm1d-3                    [1, 16]              32
              ReLU-4                    [1, 16]               0
            Linear-5                   [1, 256]           4,352
       ChannelGate-6         [1, 256, 128, 128]               0
            Conv2d-7          [1, 16, 128, 128]           4,112
              ReLU-8          [1, 16, 128, 128]               0
            Conv2d-9          [1, 16, 128, 128]           2,320
      BatchNorm2d-10          [1, 16, 128, 128]              32
             ReLU-11          [1, 16, 128, 128]               0
           Conv2d-12          [1, 16, 128, 128]           2,320
      BatchNorm2d-13          [1, 16, 128, 128]              32
             ReLU-14          [1, 16, 128, 128]               0
           Conv2d-15           [1, 1, 128, 128]              17
      SpatialGate-16         [1, 256, 128, 128]               0
================================================================
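
Adding up the two branches in the table above, the channel branch contributes 4,112 + 32 + 4,352 = 8,496 parameters and the spatial branch 8,833, so for a 256-channel feature map the whole module adds only about 17K parameters, which is why it is cheap to insert at several points in a network.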

References

Official code

A summary of receptive fields

How to understand dilated convolutions
