CoAtNet: 90.88% Paperwithcode榜單第一,層層深入考慮模型設計

【GiantPandaCV導語】CoAt=Convolution + Attention,paperwithcode榜單第一名,通過結合卷積與Transformer實現性能上的突破,方法部分設計非常規整,層層深入考慮模型的架構設計。

引言

Transformer模型的容量大,由於缺乏正確的歸納偏置,泛化能力要比卷積網絡差。

提出了CoAtNets模型族:

  • 深度可分離卷積與self-attention能夠通過簡單的相對注意力來統一化。
  • 疊加捲積層和注意層在提高泛化能力和效率方面具有驚人的效果

方法

這部分主要關注如何將conv與transformer以一種最優的方式結合:

  • 在基礎的計算塊中,如果合併卷積與自注意力操作。
  • 如何組織不同的計算模塊來構建整個網絡。

合併卷積與自注意力

卷積方面谷歌使用的是經典的MBConv, 使用深度可分離卷積來捕獲空間之間的交互。

卷積操作的表示:\(\mathcal{L}(i)\)代表i周邊的位置,也即卷積處理的感受野。

\[y_{i}=\sum_{j \in \mathcal{L}(i)} w_{i-j} \odot x_{j} \quad \text { (depthwise convolution) } \]

自注意力表示:\(\mathcal{G}\)表示全局空間感受野。

\[y_{i}=\sum_{j \in \mathcal{G}} \underbrace{\frac{\exp \left(x_{i}^{\top} x_{j}\right)}{\sum_{k \in \mathcal{G}} \exp \left(x_{i}^{\top} x_{k}\right)}}_{A_{i, j}} x_{j} \quad \text { (self-attention) } \]

融合方法一:先求和,再softmax

\[y_{i}^{\text {post }}=\sum_{j \in \mathcal{G}}\left(\frac{\exp \left(x_{i}^{\top} x_{j}\right)}{\sum_{k \in \mathcal{G}} \exp \left(x_{i}^{\top} x_{k}\right)}+w_{i-j}\right) x_{j} \]

融合方法二:先softmax,再求和

\[y_{i}^{\text {pre }}=\sum_{j \in \mathcal{G}} \frac{\exp \left(x_{i}^{\top} x_{j}+w_{i-j}\right)}{\sum_{k \in \mathcal{G}} \exp \left(x_{i}^{\top} x_{k}+w_{i-k}\right)} x_{j} \]

出於參數量、計算兩方面的考慮,論文打算採用第二種融合方法。

垂直佈局設計

決定好合並卷積與注意力的方式後應該考慮如何構建網絡整體架構,主要有三個方面的考量:

  • 使用降採樣降低空間維度大小,然後使用global relative attention。
  • 使用局部注意力,強制全局感受野限制在一定範圍內。典型代表有:
    • Scaling local self-attention for parameter efficient visual backbone
    • Swin Transformer
  • 使用某種線性注意力方法來取代二次的softmax attention。典型代表有:
    • Efficient Attention
    • Transformers are rnns
    • Rethinking attention with performers

第二種方法實現效率不夠高,第三種方法性能不夠好,因此採用第一種方法,如何設計降採樣的方式也有幾種方案:

  • 使用卷積配合stride進行降採樣。
  • 使用pooling操作完成降採樣,構建multi-stage網絡範式。
  • 根據第一種方案提出\(ViT_{REL}\), 即使用ViT Stem,直接堆疊L層Transformer block使用relative attention。
  • 根據第二種方案,採用multi-stage方案提出模型組:\(S_0,...,S_4\),如下圖所示:

\(S_o-S_2\)採用卷積以及MBConv,從\(S_2-S_4\)的幾個模塊採用Transformer 結構。具體Transformer內部有以下幾個變體:C代表卷積,T代表Transformer

  • C-C-C-C
  • C-C-C-T
  • C-C-T-T
  • C-T-T-T

初步測試模型泛化能力

泛化能力排序爲:(證明架構中還是需要存在想當比例的卷積操作)

初步測試模型容量

主要是從JFT以及ImageNet-1k上不同的表現來判定的,排序結果爲:

測試模型遷移能力

爲了進一步比較CCTT與CTTT,進行了遷移能力測試,發現CCTT能夠超越CTTT。

最終CCTT勝出!

實驗

與SOTA模型比較結果:

實驗結果:

消融實驗:

代碼

淺層使用的MBConv模塊如下:

class MBConv(nn.Module):
    def __init__(self, inp, oup, image_size, downsample=False, expansion=4):
        super().__init__()
        self.downsample = downsample
        stride = 1 if self.downsample == False else 2
        hidden_dim = int(inp * expansion)

        if self.downsample:
            self.pool = nn.MaxPool2d(3, 2, 1)
            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)

        if expansion == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
                          1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                # down-sample in the first conv
                nn.Conv2d(inp, hidden_dim, 1, stride, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1,
                          groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                SE(inp, hidden_dim),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        
        self.conv = PreNorm(inp, self.conv, nn.BatchNorm2d)

    def forward(self, x):
        if self.downsample:
            return self.proj(self.pool(x)) + self.conv(x)
        else:
            return x + self.conv(x)

主要關注Attention Block設計,引入Relative Position:

class Attention(nn.Module):
    def __init__(self, inp, oup, image_size, heads=8, dim_head=32, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == inp)

        self.ih, self.iw = image_size

        self.heads = heads
        self.scale = dim_head ** -0.5

        # parameter table of relative position bias
        self.relative_bias_table = nn.Parameter(
            torch.zeros((2 * self.ih - 1) * (2 * self.iw - 1), heads))

        coords = torch.meshgrid((torch.arange(self.ih), torch.arange(self.iw)))
        coords = torch.flatten(torch.stack(coords), 1)
        relative_coords = coords[:, :, None] - coords[:, None, :]

        relative_coords[0] += self.ih - 1
        relative_coords[1] += self.iw - 1
        relative_coords[0] *= 2 * self.iw - 1
        relative_coords = rearrange(relative_coords, 'c h w -> h w c')
        relative_index = relative_coords.sum(-1).flatten().unsqueeze(1)
        self.register_buffer("relative_index", relative_index)

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(inp, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, oup),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(
            t, 'b n (h d) -> b h n d', h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # Use "gather" for more efficiency on GPUs
        relative_bias = self.relative_bias_table.gather(
            0, self.relative_index.repeat(1, self.heads))
        relative_bias = rearrange(
            relative_bias, '(h w) c -> 1 c h w', h=self.ih*self.iw, w=self.ih*self.iw)
        dots = dots + relative_bias

        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out

參考

https://arxiv.org/pdf/2106.04803.pdf

https://github.com/chinhsuanwu/coatnet-pytorch

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章