CBAM: Convolutional Block Attention Module

1.介紹

  CBAM的中文名字是基於卷積塊的注意機制,從結構上來看,它結合了空間注意力機制和通道注意力機制,從效果上來看,它能提高分類和檢測的正確率。更加詳細的內容可以參加論文:CBAM: Convolutional Block Attention Module

  

2.模型結構

     對於以下結構,我想用比較具體的數字來表示整個流程,比如輸入的Feature維度爲1*16*10*10(各維度代表BCHW),經過通道注意力機制之後,得到的權重維度爲1*16*1*1(權重值越大,代表着那個通道特徵圖越重要),經過空間注意力機制之後,得到的權重爲1*1*10*10(權重越大,代表特徵圖上的內容越重要)

    通道注意力機制:在下圖中,MaxPool的操作就是提取一副特徵圖上的最大值,有多少通道就提取多少個;AvgPool的操作就是提取一副特徵圖上的平均值,也是有多少通道就提取多少個;MLP是一個共享(共用,只有一個)的全連接操作,後面的操作比較簡單,就不介紹了。

    空間注意力機制:在下圖中,MaxPool的操作就是在通道上提取最大值,提取的次數是高乘以寬;AvgPool的操作就是在通道上提取平均值,提取的次數也是是高乘以寬;接着將前面所提取到的特徵圖(通道數都爲1)合併得到一個2通道的特徵圖,後面的操作比較簡單,就不介紹了。

 

3.模型特點

      1.在通道注意力機制中引入全連接,並通過全連接降維,有利於提取更重要的信息(相當於PCA操作)

      2.一般注意力機制是用AvgPool操作,本文引入MaxPool操作,可以提取到不同信息,增強特徵的多樣性。

 

 4.代碼實現Pytorch

# -*-coding:utf-8-*-
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


class ChannelGate(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )

    def forward(self, x):
        avg_pool = F.avg_pool2d(
            x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
        channel_avg = self.mlp(avg_pool)
        max_pool = F.max_pool2d(
            x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
        channel_max = self.mlp(max_pool)

        channel_att_sum = channel_avg + channel_max

        scale = torch.sigmoid(channel_att_sum).unsqueeze(
            2).unsqueeze(3).expand_as(x)
        return x * scale


class ChannelPool(nn.Module):
    def forward(self, x):
        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)


class SpatialGate(nn.Module):
    def __init__(self):
        super(SpatialGate, self).__init__()
        self.compress = ChannelPool()
        self.spatial = torch.nn.Conv2d(2, 1, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out)  # broadcasting
        return x * scale


class CBAM(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(
            gate_channels, reduction_ratio, pool_types)
        self.SpatialGate = SpatialGate()

    def forward(self, x):
        x_out = self.ChannelGate(x)
        x_out = self.SpatialGate(x_out)
        return x_out

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章