CeiT:訓練更快的多層特徵抽取ViT

【GiantPandaCV導語】來自商湯和南洋理工的工作,也是使用卷積來增強模型提出low-level特徵的能力,增強模型獲取局部性的能力,核心貢獻是LCA模塊,可以用於捕獲多層特徵表示。

引言

針對先前Transformer架構需要大量額外數據或者額外的監督(Deit),才能獲得與卷積神經網絡結構相當的性能,爲了克服這種缺陷,提出結合CNN來彌補Transformer的缺陷,提出了CeiT:

(1)設計Image-to-Tokens模塊來從low-level特徵中得到embedding。

(2)將Transformer中的Feed Forward模塊替換爲Locally-enhanced Feed-Forward(LeFF)模塊,增加了相鄰token之間的相關性。

(3)使用Layer-wise Class Token Attention(LCA)捕獲多層的特徵表示。

經過以上修改,可以發現模型效率方面以及泛化能力得到了提升,收斂性也有所改善,如下圖所示:

方法

1. Image-to-Tokens

使用卷積+池化來取代原先ViT中7x7的大型patch。

\[\mathbf{x}^{\prime}=\mathrm{I} 2 \mathrm{~T}(\mathbf{x})=\operatorname{MaxPool}(\operatorname{BN}(\operatorname{Conv}(\mathbf{x}))) \]

2. LeFF

將tokens重新拼成feature map,然後使用深度可分離卷積添加局部性的處理,然後再使用一個Linear層映射至tokens。

\[\begin{aligned} \mathbf{x}_{c}^{h}, \mathbf{x}_{p}^{h} &=\operatorname{Split}\left(\mathbf{x}_{t}^{h}\right) \\ \mathbf{x}_{p}^{l_{1}} &=\operatorname{GELU}\left(\operatorname{BN}\left(\operatorname{Linear}\left(\left(\mathbf{x}_{p}^{h}\right)\right)\right)\right.\\ \mathbf{x}_{p}^{s} &=\operatorname{SpatialRestore}\left(\mathbf{x}_{p}^{l_{1}}\right) \\ \mathbf{x}_{p}^{d} &=\operatorname{GELU}\left(\operatorname{BN}\left(\operatorname{DWConv}\left(\mathbf{x}_{p}^{s}\right)\right)\right) \\ \mathbf{x}_{p}^{f} &=\operatorname{Flatten}\left(\mathbf{x}_{p}^{d}\right) \\ \mathbf{x}_{p}^{l_{2}} &=\operatorname{GELU}\left(\operatorname{BN}\left(\operatorname{Linear} 2\left(\mathbf{x}_{p}^{f}\right)\right)\right) \\ \mathbf{x}_{t}^{h+1} &=\operatorname{Concat}\left(\mathbf{x}_{c}^{h}, \mathbf{x}_{p}^{l_{2}}\right) \end{aligned} \]

3. LCA

前兩個都比較常規,最後一個比較有特色,經過所有Transformer層以後使用的Layer-wise Class-token Attention,如下圖所示:

LCA模塊會將所有Transformer Block中得到的class token作爲輸入,然後再在其基礎上使用一個MSA+FFN得到最終的logits輸出。作者認爲這樣可以獲取多尺度的表徵。

實驗

SOTA比較:

I2T消融實驗:

LeFF消融實驗:

LCA有效性比較:

收斂速度比較:

代碼

模塊1:I2T Image-to-Token

  # IoT
  self.conv = nn.Sequential(
      nn.Conv2d(in_channels, out_channels, conv_kernel, stride, 4),
      nn.BatchNorm2d(out_channels),
      nn.MaxPool2d(pool_kernel, stride)    
  )
  
  feature_size = image_size // 4

  assert feature_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
  num_patches = (feature_size // patch_size) ** 2
  patch_dim = out_channels * patch_size ** 2
  self.to_patch_embedding = nn.Sequential(
      Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
      nn.Linear(patch_dim, dim),
  )

模塊2:LeFF

class LeFF(nn.Module):
    
    def __init__(self, dim = 192, scale = 4, depth_kernel = 3):
        super().__init__()
        
        scale_dim = dim*scale
        self.up_proj = nn.Sequential(nn.Linear(dim, scale_dim),
                                    Rearrange('b n c -> b c n'),
                                    nn.BatchNorm1d(scale_dim),
                                    nn.GELU(),
                                    Rearrange('b c (h w) -> b c h w', h=14, w=14)
                                    )
        
        self.depth_conv =  nn.Sequential(nn.Conv2d(scale_dim, scale_dim, kernel_size=depth_kernel, padding=1, groups=scale_dim, bias=False),
                          nn.BatchNorm2d(scale_dim),
                          nn.GELU(),
                          Rearrange('b c h w -> b (h w) c', h=14, w=14)
                          )
        
        self.down_proj = nn.Sequential(nn.Linear(scale_dim, dim),
                                    Rearrange('b n c -> b c n'),
                                    nn.BatchNorm1d(dim),
                                    nn.GELU(),
                                    Rearrange('b c n -> b n c')
                                    )
        
    def forward(self, x):
        x = self.up_proj(x)
        x = self.depth_conv(x)
        x = self.down_proj(x)
        return x
        
class TransformerLeFF(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, scale = 4, depth_kernel = 3, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
                Residual(PreNorm(dim, LeFF(dim, scale, depth_kernel)))
            ]))
    def forward(self, x):
        c = list()
        for attn, leff in self.layers:
            x = attn(x)
            cls_tokens = x[:, 0]
            c.append(cls_tokens)
            x = leff(x[:, 1:])
            x = torch.cat((cls_tokens.unsqueeze(1), x), dim=1) 
        return x, torch.stack(c).transpose(0, 1)

模塊3:LCA

class LCAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
        q = q[:, :, -1, :].unsqueeze(2) # Only Lth element use as query

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = dots.softmax(dim=-1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out =  self.to_out(out)
        return out

class LCA(nn.Module):
    # I remove Residual connection from here, in paper author didn't explicitly mentioned to use Residual connection, 
    # so I removed it, althougth with Residual connection also this code will work.
    def __init__(self, dim, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.layers.append(nn.ModuleList([
                PreNorm(dim, LCAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x[:, -1].unsqueeze(1)
            x = x[:, -1].unsqueeze(1) + ff(x)
        return x

參考

https://arxiv.org/abs/2103.11816

https://github.com/rishikksh20/CeiT-pytorch/blob/master/ceit.py

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章