當可變形注意力機制引入Vision Transformer

【GiantPandaCV導語】通過在Transformer基礎上引入Deformable CNN中的可變性能力,在降低模型參數量的同時提升獲取大感受野的能力,文內附代碼解讀。

引言

Transformer由於其更大的感受野能夠讓其擁有更強的模型表徵能力,性能上超越了很多CNN的模型。

然而單純增大感受野也會帶來其他問題,比如說ViT中大量使用密集的注意力,會導致需要額外的內存和計算代價,特徵很容易被無關的部分所影響。

而PVT或者Swin Transformer中使用的sparse attention是數據不可知的,會影響模型對長距離依賴的建模能力。

由此引入主角:Deformabel Attention Transformer的兩個特點:

  • data-dependent: key和value對的位置上是依賴於數據的。
  • 結合Deformable 方式能夠有效降低計算代價,提升計算效率。

下圖展示了motivation:

圖中比較了幾種方法的感受野,其中紅色星星和藍色星星表示的是不同的query。而實線包裹起來的目標則是對應的query參與處理的區域。

(a) ViT對所有的query都一樣,由於使用的是全局的注意力,所以感受野覆蓋全圖。

(b) Swin Transformer中則使用了基於window劃分的注意力。不同query處理的位置是在一個window內部完成的。

(c) DCN使用的是3x3卷積核基礎上增加一個偏移量,9個位置都學習到偏差。

(d) DAT是本文提出的方法,由於結合ViT和DCN,所有query的響應區域是相同的,但同時這些區域也學習了偏移量。

方法

先回憶一下Deformable Convolution:

簡單來講是使用了額外的一個分支迴歸offset,然後將其加載到座標之上得到合適的目標。

在回憶一下ViT中的Multi-head Self-attention:

\[\begin{aligned} q&=x W_{q}, k=x W_{k}, v=x W_{v}, \\ z^{(m)}&=\sigma\left(q^{(m)} k^{(m) \top} / \sqrt{d}\right) v^{(m)}, m=1, \ldots, M, \\ z&=\text { Concat }\left(z^{(1)}, \ldots, z^{(M)}\right) W_{o}, \\ z_{l}^{\prime} &=\operatorname{MHSA}\left(\operatorname{LN}\left(z_{l-1}\right)\right)+z_{l-1}, \\ z_{l} &=\operatorname{MLP}\left(\operatorname{LN}\left(z_{l}^{\prime}\right)\right)+z_{l}^{\prime}, \end{aligned} \]

有了以上鋪墊,下圖就是本文最核心的模塊Deformable Attention。

  • 左邊這部分使用一組均勻分佈在feature map上的參照點
  • 然後通過offset network學習偏置的值,將offset施加於參照點中。
  • 在得到參照點以後使用bilinear pooling操作將很小一部分特徵圖摳出來,作爲k和v的輸入
x_sampled = F.grid_sample(
input=x.reshape(B * self.n_groups, self.n_group_channels, H, W), 
grid=pos[..., (1, 0)], # y, x -> x, y
mode='bilinear', align_corners=True) # B * g, Cg, Hg, Wg
  • 之後將得到的Q,K,V執行普通的self-attention, 並在其基礎上增加relative position bias offsets。

其中offset network構建很簡單, 代碼和圖示如下:

  self.conv_offset = nn.Sequential(
      nn.Conv2d(self.n_group_channels, self.n_group_channels, kk, stride, kk//2, groups=self.n_group_channels),
      LayerNormProxy(self.n_group_channels),
      nn.GELU(),
      nn.Conv2d(self.n_group_channels, 2, 1, 1, 0, bias=False)
  )

最終網絡結構爲:

具體參數如下:

實驗

實驗配置:300epoch,batch size 1024, lr=1e-3,數據增強大部分follow DEIT

  • 分類結果:

目標檢測數據集結果:

語義分割:

  • 消融實驗:

  • 可視化結果:COCO

這個可視化結果有點意思,如果是分佈在背景上的點大部分變動不是很大,即offset不是很明顯,但是目標附近的點會存在一定的集中趨勢(ps:這種趨勢沒有Deformable Conv中的可視化結果明顯)

代碼

  • 生成Q
  B, C, H, W = x.size()
  dtype, device = x.dtype, x.device
  
  q = self.proj_q(x)
  • offset network前向傳播得到offset
  q_off = einops.rearrange(q, 'b (g c) h w -> (b g) c h w', g=self.n_groups, c=self.n_group_channels)
  offset = self.conv_offset(q_off) # B * g 2 Hg Wg
  Hk, Wk = offset.size(2), offset.size(3)
  n_sample = Hk * Wk
  • 在參照點基礎上使用offset
offset = einops.rearrange(offset, 'b p h w -> b h w p')
reference = self._get_ref_points(Hk, Wk, B, dtype, device)
    
if self.no_off:
    offset = offset.fill(0.0)
    
if self.offset_range_factor >= 0:
    pos = offset + reference
else:
    pos = (offset + reference).tanh()
  • 使用bilinear pooling的方式將對應feature map摳出來,等待作爲k,v的輸入。
x_sampled = F.grid_sample(
    input=x.reshape(B * self.n_groups, self.n_group_channels, H, W), 
    grid=pos[..., (1, 0)], # y, x -> x, y
    mode='bilinear', align_corners=True) # B * g, Cg, Hg, Wg
    
x_sampled = x_sampled.reshape(B, C, 1, n_sample)

q = q.reshape(B * self.n_heads, self.n_head_channels, H * W)
k = self.proj_k(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)
v = self.proj_v(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)
  • 在positional encodding部分引入相對位置的偏置:
  rpe_table = self.rpe_table
  rpe_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
  
  q_grid = self._get_ref_points(H, W, B, dtype, device)
  
  displacement = (q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) - pos.reshape(B * self.n_groups, n_sample, 2).unsqueeze(1)).mul(0.5)
  
  attn_bias = F.grid_sample(
      input=rpe_bias.reshape(B * self.n_groups, self.n_group_heads, 2 * H - 1, 2 * W - 1),
      grid=displacement[..., (1, 0)],
      mode='bilinear', align_corners=True
  ) # B * g, h_g, HW, Ns
  
  attn_bias = attn_bias.reshape(B * self.n_heads, H * W, n_sample)
  
  attn = attn + attn_bias

參考

https://github.com/LeapLabTHU/DAT

https://arxiv.org/pdf/2201.00520.pdf

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章