CvT: 如何將卷積的優勢融入Transformer

【GiantPandaCV導語】與之前BoTNet不同,CvT雖然題目中有卷積的字樣,但是實際總體來說依然是以Transformer Block爲主的,在Token的處理方面引入了卷積,從而爲模型帶來的局部性。最終CvT最高拿下了87.7%的Top1準確率。

引言

CvT架構的Motivation也是將局部性引入Vision Transformer架構中,期望通過引入局部性得到更高的性能和效率權衡。因此我們主要關注CvT是如何引入局部性的。具體來說提出了兩點改進:

  • Convolutional token embedding
  • Convolutional Projection

通過以上改進,模型不僅具有卷積的優勢(局部感受野、權重共享、空間下采樣等特性帶來的優勢),如平移不變形、尺度不變性、旋轉不變性等,也保持了Self Attention的優勢,如動態注意力、全局語義信息、更強的泛化能力等。

展開一點講,Convolutional Vision Transformer有兩點核心:

  • 第一步,參考CNN的架構,將Transformer也設計爲多階段的層次架構,每個stage之前使用convolutional token embedding,通過使用卷積+layer normalization能夠實現降維的功能(注:逐漸降低序列長度的同時,增加每個token的維度,可以類比卷積中feature map砍半,通道數增加的操作)
  • 第二步,使用Convolutional Projection取代原來的Linear Projection,該模塊實際使用的是深度可分離卷積實現,這樣也能有效捕獲局部語義信息。

需要注意的是:CvT去掉了Positional Embedding模塊,發現對模型性能沒有任何影響。認爲可以簡化架構的設計,並且可以在分辨率變化的情況下更容易適配。

比較

在相關工作部分,CvT總結了一個表格,比較方便對比:

方法

在引言部分已經講得比較詳細了,下面對照架構圖覆盤一下(用盡可能通俗的語言描述):

  • 綠色框是conv token embedding操作,通俗來講,使用了超大卷積核來提升局部性不足的問題。
  • 右圖藍色框中展示的是改進的self attention,通俗來講,使用了non local的操作,使用深度可分離卷積取代MLP做Projection,如下圖所示:

  • 如圖(a)所示,Vision Transformer中使用的是MLP進行Linear Projection, 這樣的信息是全局性的,但是計算量比較大。
  • 如圖(b)所示,使用卷積進行映射,這種操作類似Non Local Network,使用卷積進行映射。
  • 如圖(c)所示,使用的是帶stride的卷積進行壓縮,這樣做是處於對效率的考量,token數量可以降低四倍,會帶來一定的性能損失。

Positional embedding探討:

由於Convolutional Projection在每個Transformer Block中都是用,配合Convolutional Token Embedding操作,能夠給模型足夠的能力來建模局部空間關係,因此可以去掉Transformer中的Positional Embedding操作。從下表發現,pe對模型性能影響不大。

與其他工作的對比:

  • 同期工作1:Tokens-to-Tokens ViT: 使用Progressive Tokenization整合臨近token,使用Transformer-based骨幹網絡具有局部性的同時,還能降低token序列長度。
  • 區別:CvT使用的是multi-stage的過程,token長度降低的同時,其維度在增加,從而保證模型的容量。同時計算量相比T2T有所改善。
  • 同期工作2:Pyramid Vision Transformer(PVT): 引入了金字塔架構,使得PVT可以作爲Backbone應用於Dense prediction任務中。
  • 區別:CvT也使用了金字塔架構,區別在於CvT中提出使用stride卷積來實現空間降採樣,進一步融合了局部信息。

最終模型架構如下:

實驗

左圖中令人感興趣的是BiT,這篇是谷歌的文章big transfer,探究CNN架構在大規模數據與訓練的效果,可以看出即便是純CNN架構模型參數量也可以非常巨大,而Vision Transformer還有CvT等在同等精度下模型參數量遠小於BiT,這一定程度上說明了Transformer結合CNN在數據量足夠的情況下性能可以非常可觀,要比單純CNN架構的模型性能更優。

右圖展示了CvT和幾種vision transformer架構的性能比較,可見CvT在權衡方面做的非常不錯。

與SOTA比較:

有趣的是CvT-13-NAS也採用了搜索的方法DA-NAS,主要搜索對象是key和value的stride,以及MLP的Expansion Ratio, 最終搜索的結果要比Baseline略好。

在無需JFT數據集的情況下,CvT最高調整可以達到87.7%的top1 準確率。

其他數據集結果:

消融實驗

代碼

Convolutional Token Embedding代碼實現:可以看出,實際上就是大卷積核+大Stride的滑動引入的局部性。

class ConvEmbed(nn.Module):
    """ Image to Conv Embedding
    """
    def __init__(self,
                 patch_size=7,
                 in_chans=3,
                 embed_dim=64,
                 stride=4,
                 padding=2,
                 norm_layer=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size

        self.proj = nn.Conv2d(
            in_chans, embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=padding
        )
        self.norm = norm_layer(embed_dim) if norm_layer else None

    def forward(self, x):
        x = self.proj(x)

        B, C, H, W = x.shape
        x = rearrange(x, 'b c h w -> b (h w) c')
        if self.norm:
            x = self.norm(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)

        return x

Convolutional Projection代碼實現,具體看_build_projection函數:

class Attention(nn.Module):
    def __init__(self,
                 dim_in,
                 dim_out,
                 num_heads,
                 qkv_bias=False,
                 attn_drop=0.,
                 proj_drop=0.,
                 method='dw_bn',
                 kernel_size=3,
                 stride_kv=1,
                 stride_q=1,
                 padding_kv=1,
                 padding_q=1,
                 with_cls_token=True,
                 **kwargs
                 ):
        super().__init__()
        self.stride_kv = stride_kv
        self.stride_q = stride_q
        self.dim = dim_out
        self.num_heads = num_heads
        # head_dim = self.qkv_dim // num_heads
        self.scale = dim_out ** -0.5
        self.with_cls_token = with_cls_token

        self.conv_proj_q = self._build_projection(
            dim_in, dim_out, kernel_size, padding_q,
            stride_q, 'linear' if method == 'avg' else method
        )
        self.conv_proj_k = self._build_projection(
            dim_in, dim_out, kernel_size, padding_kv,
            stride_kv, method
        )
        self.conv_proj_v = self._build_projection(
            dim_in, dim_out, kernel_size, padding_kv,
            stride_kv, method
        )

        self.proj_q = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.proj_k = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.proj_v = nn.Linear(dim_in, dim_out, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim_out, dim_out)
        self.proj_drop = nn.Dropout(proj_drop)

    def _build_projection(self,
                          dim_in,
                          dim_out,
                          kernel_size,
                          padding,
                          stride,
                          method):
        if method == 'dw_bn':
            proj = nn.Sequential(OrderedDict([
                ('conv', nn.Conv2d(
                    dim_in,
                    dim_in,
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    bias=False,
                    groups=dim_in
                )),
                ('bn', nn.BatchNorm2d(dim_in)),
                ('rearrage', Rearrange('b c h w -> b (h w) c')),
            ]))
        elif method == 'avg':
            proj = nn.Sequential(OrderedDict([
                ('avg', nn.AvgPool2d(
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    ceil_mode=True
                )),
                ('rearrage', Rearrange('b c h w -> b (h w) c')),
            ]))
        elif method == 'linear':
            proj = None
        else:
            raise ValueError('Unknown method ({})'.format(method))

        return proj

    def forward_conv(self, x, h, w):
        if self.with_cls_token:
            cls_token, x = torch.split(x, [1, h*w], 1)

        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

        if self.conv_proj_q is not None:
            q = self.conv_proj_q(x)
        else:
            q = rearrange(x, 'b c h w -> b (h w) c')

        if self.conv_proj_k is not None:
            k = self.conv_proj_k(x)
        else:
            k = rearrange(x, 'b c h w -> b (h w) c')

        if self.conv_proj_v is not None:
            v = self.conv_proj_v(x)
        else:
            v = rearrange(x, 'b c h w -> b (h w) c')

        if self.with_cls_token:
            q = torch.cat((cls_token, q), dim=1)
            k = torch.cat((cls_token, k), dim=1)
            v = torch.cat((cls_token, v), dim=1)

        return q, k, v

    def forward(self, x, h, w):
        if (
            self.conv_proj_q is not None
            or self.conv_proj_k is not None
            or self.conv_proj_v is not None
        ):
            q, k, v = self.forward_conv(x, h, w)

        q = rearrange(self.proj_q(q), 'b t (h d) -> b h t d', h=self.num_heads)
        k = rearrange(self.proj_k(k), 'b t (h d) -> b h t d', h=self.num_heads)
        v = rearrange(self.proj_v(v), 'b t (h d) -> b h t d', h=self.num_heads)

        attn_score = torch.einsum('bhlk,bhtk->bhlt', [q, k]) * self.scale
        attn = F.softmax(attn_score, dim=-1)
        attn = self.attn_drop(attn)

        x = torch.einsum('bhlt,bhtv->bhlv', [attn, v])
        x = rearrange(x, 'b h t d -> b t (h d)')

        x = self.proj(x)
        x = self.proj_drop(x)

        return x

參考

https://github.com/microsoft/CvT/blob/main/lib/models/cls_cvt.py

https://arxiv.org/pdf/2103.15808.pdf

https://zhuanlan.zhihu.com/p/142864566

筆者在cifar10數據集上修改了CvT中的Stride等參數,在不用任何數據增強和Trick的情況下得到了下圖結果,Top1爲84.74。雖然看上去性能比較差,但是這還沒有調參以及加上數據增強方法,只訓練了200個epoch的結果。

python train.py --model 'cvt' --name "cvt" --sched 'cosine' --epochs 200 --lr 0.01

感興趣的可以點擊下面鏈接調參:

https://github.com/pprp/pytorch-cifar-model-zoo

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章