通道注意力改進ECANet

 

  • 論文鏈接:https://arxiv.org/abs/1910.03151

  • 代碼地址:https://github.com/BangguWu/ECANet

ECANet主要對SENet模塊進行了一些改進,提出了一種不降維的局部跨信道交互策略(ECA模塊)和自適應選擇一維卷積核大小的方法,從而實現了性能上的提優。最近已經有很多文章在通道和空間注意力上做改進並取得了性能提升。例如SKNet,SANet,ResNeSt等等,不得不說,注意力機制真的香!

因此,足者提出了一種不降維的局部跨信道交互策略,該策略可以通過一維卷積有效地實現。進一步,作者又提出了一種自適應選擇一維卷積核大小的方法,以確定局部跨信道交互的覆蓋率。

具體來說,在給定輸入特徵的情況下,SE塊首先對每個通道單獨使用全局平均池化,然後使用兩個具有非線性的完全連接(FC)層,然後使用一個Sigmoid函數來生成通道權值。兩個FC層的設計是爲了捕捉非線性的跨通道交互,其中包括降維來控制模型的複雜性。雖然該策略在後續的通道注意模塊中得到了廣泛的應用,但作者的實驗研究表明,降維對通道注意預測帶來了副作用,捕獲所有通道之間的依賴是低效的,也是不必要的。

在不降低維數的通道級全局平均池化之後,ECA通過考慮每個通道及其k個鄰居來捕獲局部跨通道交互信息。實踐證明,該方法保證了模型效率和計算效果。需要注意的是,ECA可以通過大小爲k的快速1D卷積來有效實現,其中卷積核大小爲k代表了局部跨信道交互的覆蓋率,即,該通道附近有多少鄰居參與了這個信道的注意力預測,爲了避免通過交叉驗證對k進行手動調優,本文提出了一種方法來自適應地確定k,其中交互的覆蓋率(即卷積核大小 k)與通道維數成正比。

import torch
from torch import nn
from torch.nn.parameter import Parameter

class eca_layer(nn.Module):
    """Constructs a ECA module.
    Args:
        channel: Number of channels of the input feature map
        k_size: Adaptive selection of kernel size
    """
    def __init__(self, channel, k_size=3):
        super(eca_layer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) 
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x: input features with shape [b, c, h, w]
        b, c, h, w = x.size()

        # feature descriptor on the global spatial information
        y = self.avg_pool(x)

        # Two different branches of ECA module
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)

        # Multi-scale information fusion
        y = self.sigmoid(y)

        return x * y.expand_as(x)

 

from torch import nn
from .eca_module import eca_layer

__all__ = ['ECA_MobileNetV2', 'eca_mobilenet_v2']


model_urls = {
    'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
}


class ConvBNReLU(nn.Sequential):
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        padding = (kernel_size - 1) // 2
        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
            nn.BatchNorm2d(out_planes),
            nn.ReLU6(inplace=True)
        )


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio, k_size):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pw
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
        layers.extend([
            # dw
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ])
        layers.append(eca_layer(oup, k_size))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class ECA_MobileNetV2(nn.Module):
    def __init__(self, num_classes=1000, width_mult=1.0):
        super(ECA_MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        inverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # building first layer
        input_channel = int(input_channel * width_mult)
        self.last_channel = int(last_channel * max(1.0, width_mult))
        features = [ConvBNReLU(3, input_channel, stride=2)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                if c <= 96:
                    ksize = 1
                else:
                    ksize = 3
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t, k_size=ksize))
                input_channel = output_channel
        # building last several layers
        features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
        # make it nn.Sequential
        self.features = nn.Sequential(*features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.25),
            nn.Linear(self.last_channel, num_classes),
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.features(x)
        x = x.mean(-1).mean(-1)
        x = self.classifier(x)
        return x


def eca_mobilenet_v2(pretrained=False, progress=True, **kwargs):
    """
    Constructs a ECA_MobileNetV2 architecture from
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    model = ECA_MobileNetV2(**kwargs)
    # if pretrained:
    #     state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
    #                                           progress=progress)
    #     model.load_state_dict(state_dict)
    return model

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章