SSD(Single Shot MultiBox Detector)中的框迴歸詳解(真的詳細!!!)

很多人都知道Anchor-based的目標檢測網絡按照階段來分有兩類,以RCNN爲代表的的二階段檢測和SSD,Yolo爲代表的一階段網絡。其中二階段網絡的RPN和一階段網絡的輸出,損失函數都有框迴歸的部分。本篇博客就詳細解釋框迴歸的原理,步驟和Smoth L1 loss.

前言

在正式講解之前,我們需要知道,網絡的輸出不是框的座標和寬,高。這些座標,寬高需要從輸出向量解碼得到。我們都知道,目標檢測網絡box分支輸出的是離GT的偏置(offset)。所以box分支的label,其實並不是座標值,而是要把座標值先按照一定格式要求編碼爲一個四維向量,四個值分別描述了框中心點xy離GT框中心點的偏置,以及寬和GT寬的尺度變化信息,以及高和GT的高的尺度變化信息。

框迴歸是啥

anchor-base目標檢測方法是有anchor的。爲簡單描述起見,以下提到的anchor統統都在原圖空間下。
如下圖所示,綠色框是網絡預定義的一個anchor,這個anchor對應的score非常高,所以網絡會把這個anchor挑出來作爲框的輸出,但是anchor這個東西是死的,即預先按照位置比例以及長寬比以及anchor size提前預定的,所以anchor一般來說都是沒法準確框到目標的,我們需要在anchor的基礎上進行框迴歸,或者認爲是調整。
一個框的信息用四個值能描述,框中心的xy座標,以及框的寬高,這四個值。對於中心點的座標的調整,其實就是偏移,數學運算是加和減;而對於寬和高的調整,數學運算是乘和除。因此在後面的部分我們可以看出anchor是如何迴歸(解碼)得到最後的框位置。
在這裏插入圖片描述

訓練過程中,編碼GT

假設GT是xc,yc,w,hx_c,y_c,w,h,意義分別是GT的中心座標和寬高。但這四個值是不能直接參與訓練的。我們要先把它轉化(或稱編碼)爲相對於anchor的偏置量,偏移量參加訓練,計算loss。那麼具體如何編碼GT呢?
首先對於一個特徵圖FRH×WF \in R^{H \times W}, 存在H×W×kH \times W \times k anchor, k是按照不同長寬比在同一個位置設立k個anchor的意思。
然後從這麼多anchor中,計算IOU,選擇和GT的交併比大於一個閾值的anchor,這些anchor被認爲是正樣本,是要參與計算框迴歸的。另外,這些anchor也要參與分類任務,但本文只涉及框迴歸這一部分,事實上分類損失也很好理解。
從IOU過濾之後,保留下的anchor中,爲了通用性的解釋,我下面就以一個anchor爲例子,這個anchor是滿足上述條件的。這個anchor就看做是上面圖例的綠色框吧,簡單起見,記作AA
AA是四維向量,值記作(axc,ayc,aw,ah)( ax_c, ay_c, aw, ah),意義分別是anchor的中心xy座標,anchor的寬和高。這樣就可以開始計算針對於這個anchor,對應的計算loss需要使用的真值偏置向量了。
dx=(xcaxc)/awd_x = (x_c - ax_c) / aw
dy=(ycayc)/ahd_y = (y_c - ay_c) / ah
dw=log(w/aw)d_w = log(w / aw)
dh=log(h/ah)d_h = log(h/ah)
gA=(dx,dy,dw,dh)g_A = (d_x, d_y, d_w, d_h)

網絡最後的直接輸出(我用直接形容,意思是最後的卷積輸出,沒有其他後處理)就是anchor的偏置項,記作pAp_A,意思是anchor AA 的預測偏置項。而公式最下面的gAg_A,意思是AA的真值GT。
然後A的框迴歸的loss就是:
lossregA=smoothL1(pA,gA)loss^A_{reg} = smooth L1 (p_A, g_A)

另外我還沒提到的是,所有的值都是被歸一化的。比如xc,yc,w,hx_c,y_c,w,h 這四個值其實是需要處理一下,橫座標和寬需要除以圖像的寬,縱座標和高,需要除以圖像的高,來做歸一化。至於(axc,ayc,aw,ah)( ax_c, ay_c, aw, ah),則對應除以特徵圖的寬和高

測試過程中,偏置項的解碼

訓練結束之後,我們要測試,需要輸出的不是偏置項,而是實實在在的框的位置。解碼方式就是編碼的反轉。其中(axc,ayc,aw,ah)( ax_c, ay_c, aw, ah)是已知的,因爲是預定義的。dx,dy,dw,dhd_x, d_y, d_w, d_h都換成是網絡的直接輸出值。這樣就可以反過來去求xc,yc,w,hx_c,y_c,w,h ,這四個值就是網絡的框的位置輸出(間接輸出,因爲由解碼得到)。

從代碼來看編碼解碼

先看編碼

	# matched 是框的GT,順序是x,y, w,h ,下面一句通過先求框的中心點,然後框的中心點減去anchor的中心點得到差值
   g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2]
  # 差值除以anchor的寬和高,分別得到上述前兩個公式dx, dy
        g_cxcy /= priors[:, 2:]
        # 先求GT的寬和高與 anchor的寬和高的比值
        g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
        # 然後求log,也是一一與上述步驟對應。
        g_wh = torch.log(g_wh)
        # 
        loc = torch.cat([g_cxcy, g_wh], 1)  # [num_priors,4]

再看解碼

    boxes = torch.cat((
    		# anchor 的x y座標加上 預測的xy的偏置項 乘 anchor的寬和高,就是前兩行公式的反推。
            priors[:, :2] + loc[:, :2] * priors[:, 2:],
            priors[:, 2:] * torch.exp(loc[:, 2:] )), 1)
        # 框左上角座標
        boxes[:, :2] -= boxes[:, 2:] / 2
        # 框右下角座標
        boxes[:, 2:] += boxes[:, :2]

結束語

到此爲止,我的講解結束了,如果覺得不錯的,請點個贊或者留言鼓勵哦。優質博客,歡迎轉載,轉載請註明出處。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章