End-to-End Object Detection with Transformers[DETR]

背景

最近在做機器翻譯的優化,接觸的模型就是transformer, 爲了提升性能,在cpu和GPU兩個平臺c++重新寫了整個模型,所以對於機器翻譯中transformer的原理細節還是有一定的理解,同時以前做文檔圖片檢索對於圖像領域的目標檢測也研究頗深,看到最近各大公衆號都在推送這篇文章就簡單的看了一下,感覺還是蠻有新意的,由於該論文開源,所以直接就跟着代碼來解讀整篇論文。

概述

在這裏插入圖片描述
整體來看,該模型首先是經歷一個CNN提取特徵,然後得到的特徵進入transformer, 最後將transformer輸出的結果轉化爲class和box.

 def forward(self, samples):
		"""
		這一段代碼時從源碼detr.py的DETR中抽出來的代碼,爲了邏輯清爽,刪除了一些
		細枝末節的內容,核心邏輯如下
		"""
		#backbone模型中核心就是圖中的CNN模型,可以自己選擇resnet,vgg什麼的,features就是卷積後的輸出
        features, pos = self.backbone(samples)#sample 就是圖片,大小比如(3,200,250)
        src, mask = features[-1].decompose()
        #transformer模型處理一波
        hs = transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]
        #transformer模型的最終結果爲hs,將其分別進入class和box的模型中處理得到class和box
        outputs_class = class_embed(hs)
        outputs_coord = bbox_embed(hs).sigmoid()
        out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
        return out

下面是大致的推理過程:
在這裏插入圖片描述

相關技術

輸入

作者這裏封裝了一個類,感覺多此一舉,假如我們輸入的是如下兩張圖片,也就說batch爲2:
img1 = torch.rand(3, 200, 200),
img2 = torch.rand(3, 200, 250)

x = nested_tensor_from_tensor_list([torch.rand(3, 200, 200), torch.rand(3, 200, 250)])

這裏會轉成nested_tensor, 這個nestd_tensor是什麼類型呢?簡單說就是把{tensor, mask}打包在一起, tensor就是我麼的圖片的值,那麼mask是什麼呢? 當一個batch中的圖片大小不一樣的時候,我們要把它們處理的整齊,簡單說就是把圖片都padding成最大的尺寸,padding的方式就是補零,那麼batch中的每一張圖都有一個mask矩陣,所以mask大小爲[2, 200,250], tensor大小爲[2,3,200,250]。

提取特徵

接下里就是把tensor, 也就是圖片輸入到特徵提取器中,這裏作者使用的是殘差網絡,我做實驗的時候用多個resnet-50, 所以tensor經過resnet-50後的結果就是[2,2048,7,8],下面是殘差網絡最後一層的結構。

(2): Bottleneck(
(conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d()
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d()
(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d()
(relu): ReLU(inplace=True)

別忘了,我們還有個mask, mask採用的方式F.interpolate,最後得到的結果是[2,7,8]

獲取position_embedding

這裏作者使用的三角函數的方式獲取position_embediing, 如果你對位置編碼不瞭解,你可以這樣理解,“我愛祖國”,“我”位於第一位,如果編碼後不加入位置信息,那麼“我”這個字的編碼信息就是不完善的,所以這裏也一樣,下面是源碼,有興趣的可以推導一下,position_embediing的輸入是上面的NestedTensor={tensor,mask}, 輸出最終pos的size爲[1,2,256,7,8]。

def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos

transformer

transformer分爲編碼和解碼,下面分別介紹:

encoder

經過上面一系列操作以後,目前我們擁有src=[ 2, 2048,7,8],mask=[2,7,8], pos=[1,2,256,7,8]

hs = transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]#

input_proj:一個卷積層,卷積核爲1*1,說白了就是將壓縮通道的作用,將2048壓縮到256,所以傳入transformer的維度是壓縮後的[2,256,7,8]。
self.query_embed.weight:現在還用不到,在decoder的時候用的到,到時候再說。
來看一下transformer

class Transformer(nn.Module):

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False):
        super().__init__()
		# encode
		# 單層
        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        # 由6個單層組成整個encoder
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
		#decode
        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec)

爲了更清楚看到具體模型結構
在這裏插入圖片描述
根據代碼和模型結構可以看到,encoder部分就是6個TransformerEncodeLayer組成,而每一個編碼層又由1個self_attention, 2個ffn,2個norm。
在進行encoder之前先還有個處理:

bs, c, h, w = src.shape# 這個和我們上面說的一樣[2,256,7,8]
src = src.flatten(2).permute(2, 0, 1) # src轉爲[56,2,256]
pos_embed = pos_embed.flatten(2).permute(2, 0, 1)# pos_embed 轉爲[56,2,256]
mask = mask.flatten(1) #mask 轉爲[2,56]

encoder的輸入爲:src, mask, pos_embed,接下來捋一捋第一個單層encoder的過程

 q = k = self.with_pos_embed(src, pos)# pos + src
 src2 = self.self_attn(q, k, value=src, key_padding_mask=mask)[0]
 #做self_attention,這個不懂的需要補一下transfomer的知識
 src = src + self.dropout1(src2)# 類似於殘差網絡的加法
 src = self.norm1(src)# norm,這個不是batchnorm,很簡單不在詳述
 src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))#兩個ffn
 src = src + self.dropout2(src2)# 同上殘差加法
 src = self.norm2(src)# norm
 return src

根據模型的代碼可以看到單層的輸出依然爲src[56, 2, 256],第二個單層的輸入依然是:src, mask, pos_embed。循環往復6次結束encoder,得到輸出memory, memory的size依然爲[56, 2, 256].

decoder

encoder結束後我們來看decoder, 先看代碼:

tgt = torch.zeros_like(query_embed)
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
                  pos=pos_embed, query_pos=query_embed)
                   

現在來找輸入:

  1. memory:這個就是encoder的輸出,size爲[56,2,256]
  2. mask:還是上面的mask
  3. pos_embed:還是上面的pos_embed
  4. query_embed:?
  5. tgt: 每一層的decoder的輸入,第一層的話等於0

所以目前我們只要知道query_embed就行了,這個query_embed其實是一個varible,size=[100,2,256],由訓練得到,結束後就固定下來了。到目前爲止我們獲得了decoder的所有輸入,和encoder一樣我們先來看看單層的decoder的運行流程:

如果你不知道100是啥,那你多少需要看一眼論文,這個100表示將要預測100個目標框,你問爲什麼是100框,因爲作者用的數據集的目標種類有90個,萬一一個圖上有90個目標你至少都能檢測出來吧,所以100個框合理。此外這裏和語言模型的輸入有很大區別,比如翻譯時自迴歸,也就是說翻譯出一個字,然後把這個字作爲下一個解碼的輸入(這裏看不懂的可以去看我博客裏將transformer的那一篇),作者這裏直接用[100, 256]作爲輸入感覺也是蠻厲害的。

 q = k = self.with_pos_embed(tgt, query_pos)# tgt + query_pos, 第一層的tgt爲0
 tgt2 = self.self_attn(q, k, value=tgt, key_padding_mask=mask)[0]# 同上
 tgt = tgt + self.dropout1(tgt2)
 tgt = self.norm1(tgt)
 tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                            key=self.with_pos_embed(memory, pos),
                            value=memory, 
                            key_padding_mask=mask)[0]#交叉attention
 tgt = tgt + self.dropout2(tgt2)
 tgt = self.norm2(tgt)
 tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
 tgt = tgt + self.dropout3(tgt2)
 tgt = self.norm3(tgt)
 return tgt

這裏的難點可能是交叉attention,也叫encoder_decoder_attention, 這裏利用的是encoder的輸出來參與計算,裏面的計算細節同樣可以參考這裏,經過上面六次的處理,最後得到的結果爲[100,2,256], 返回的時候做一個轉換,最終的結果transpose(1, 2)->[100,256,2]。

迴歸

class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
        
 self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
 self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
 
 outputs_class = self.class_embed(hs)
 outputs_coord = self.bbox_embed(hs).sigmoid()
 out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}

這幾行代碼就不解釋了,至於爲什麼是output_calss[-1], 作爲思考題留給大家,如果整個源碼擼一遍的話就會知道原因,總的來說最後迴歸的邏輯比較簡單清晰,下面是最後的結果:
pred_logits:[2,100,92]
outputs_coord:[2,100,4]

總結

以上就是整個DETR的推理過程,在訓練的時候還涉及到100個框對齊的問題,也不難這裏就不再講述了,如果想徹底理解整個模型,你需要對卷積,attention有比較深刻的理解,不然即使看懂了流程也不明白爲什麼這樣做,該論文的坑位目測還不少,而且對於目標檢測的模型來說這個代碼量算是少的了,給起來也快,需要畢業的孩紙抓緊啦,哈哈哈

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章