論文鏈接:RetinaNet
摘要
目前目標檢測任務中精度最高的模型是基於主流的R-CNN框架的二階段模型,該類方法在一些列目標候選框上進行分類。相對的,一階段模型直接在大量的可能包含目標的區域進行檢測,這樣做速度更快但相比於兩階段模型也犧牲了精度,我們在本文工作中分析了這個問題的原因。我們發現訓練過程中正負樣本(指前景和背景)之間嚴重的不平衡是主要原因。我們通過修改標準的交叉信息熵損失函數來解決類別之間的不平衡,使得那些被很好地分類的樣本的權重降低。我們提出的Focal Loss在訓練中更關注那些難分類的樣本,抑制了那些易分類的負樣本的主導作用。爲了評估Focal Loss的作用,我們設計並訓練了一個簡單的檢測器即RetinaNet。我們的實驗結果表明當使用Focal Loss進行訓練時,RetinaNet可以在達到當前一階段檢測模型速度的同時超過現有的所有排名靠前的二階段模型。源碼地址: https://github.com/facebookresearch/Detectron.(目前Detectron已經更新到Detectron2)
動機
作者在文中指出,一階段模型雖然在速度上比二階段模型快很多,但在精度上卻也不如二階段模型。而導致這個問題的最根本原因是一階段模型的正負樣本不平衡要比二階段模型嚴重德多,這樣就會導致模型在訓練過程中背景佔據主導作用。那爲什麼一階段模型的樣本不平衡要比二階段模型嚴重呢?這取決於兩者候選框的生成機制:
- 二階段模型使用諸如RPN的網絡來生成候選框,在生成候選框的時候控制了候選框的數量,並且過濾掉了大部分的背景框。在訓練分類器的時候,又嚴格控制了正負樣本的比例(如典型的1:3)。
- 一階段模型中並沒有像RPN網絡這樣的機制,只能在原圖上生成大量的候選框(高達數萬個),其中大多數是背景框,導致候選框中的正負樣本比例嚴重不平衡。雖然可以採用一些重採樣的策略(如困難樣本挖掘),但背景框依然佔據着絕對的主導作用,導致模型在訓練過程過擬合問題嚴重。
基於上述分析,文章從檢測模型的類別損失函數入手,基於經典的交叉信息熵損失進行改進,使得模型在訓練過程中自動權衡多數簡單樣本和少數困難樣本的權重,從而解決一階段模型中樣本失衡的問題。
主要工作
Focal Loss
Focal Loss是對典型的交叉信息熵損失函數的改進。對於一個二分類問題,交叉信息熵損失函數定義如下:
爲了同一正負樣本的損失函數表達式,做如下定義:
在形式上就表示被預測爲對應的正確類別的置信度。這樣二分類交叉信息熵損失就可以重寫成如下形式:
爲了平衡多數類和少數類的損失,一種常規的思想就是在損失項前乘上一個平衡係數,當類別爲正時,取,當類別爲負時,取,這樣得到的帶有平衡係數的交叉信息熵損失定義如下:
這樣,根據訓練樣本中正負樣本數量來選取的值,就可以達到平衡正負樣本的作用。然而,這樣做還不能對簡單和困難樣本區別對待,在目標檢測中,既要平衡多數類(背景)和少數類(包含目標的前景),還要平衡簡單樣本和困難樣本,而往往訓練過程中往往遇到的問題就是大量簡單的背景樣本佔據損失函數的主要部分。因此,還需要對上述帶有平衡係數的交叉信息熵損失做進一步的改進。於是就有了文章的核心內容Focal Loss,它定義如下:
相比於上面的加了平衡係數的損失函數相比,Focal Loss有以下兩點不同:
- 固定的平衡係數替換成了可變的平衡係數
- 多了另外一個調節因子,且
之前已經提到過,在形式上表示樣本被預測爲對應的正確類別的置信度,因此,可以理解爲越接近於1,樣本被正確分類的概率越高,也就意味着這個樣本越容易分類,即簡單樣本。當越接近於0,表明樣本被錯誤分類的概率越大,意味着這個樣本越難分類,即困難樣本。這時候再看前面的係數,在越接近於1時,該項的值越小,反之越大,而由於始終大於0,它對的結果做了進一步的放縮。舉個例子,當時,假設有一個正樣本預測正確的置信度,可以說是一個容易分類的樣本,此時平衡的係數項爲,意味它產生的損失變爲原來的1/100;相反地,若一個樣本被正確分類的置信度,說明這是一個難分類的樣本,此時它對應的平衡係數的結果爲,衰減程度遠小於置信度爲0.9的樣本。論文也對比了在取不同的值時的loss變化,如下圖:
可以發現,取值越大,分類結果好的樣本對應的縮放幅度越大。而在實驗中文章也指出在的時候結果最好。
RetinaNet框架
模型結構。 RetinaNet整體上是一個一階段模型,由一個主幹網絡和兩個分支網絡組成。主幹網絡在ResNet結構採用特徵金字塔,通過引入橫向的連接融合不同層次的特徵圖;兩個分支網絡採用相同結構的全卷積網絡(參數不共享),第一個分支負責預測類別信息;第二個分支網絡用於邊界框迴歸。從整體上看,RetinaNet是ResNet+FPN+FCN的組合使用。
錨點機制。 RetinaNet中也採用了類似RPN中的錨點機制,在每個特徵金字塔層上都使用了3種長寬比的錨點,每個長寬比的錨點又有3個不同的尺度,共9個錨點。
推理過程。 當一張圖片輸入到網絡進行前向傳播時,在每個特徵金字塔層級上會預測很多候選框,爲了提高推理速度,RetinaNet種只取每層輸出的前1000個置信度最高的候選框,最後將所有層級得到的候選框放到一起進行非極大值抑制(閾值取0.5),得到最終的檢測結果。
實驗結果
從實驗對比中可以看出,基於ResNet-101-FPN和ResNeXt-101-FPN的RetinaNet幾乎在COCO目標檢測的各個指標上都明顯高於主流的二階段模型和一階段模型,而在檢測速度上,RetinaNet也與SSD框架相當。
總結
文章最主要的貢獻是提出了Focal Loss來處理模型訓練過程中少數難分類樣本和多數簡單樣本的不平衡問題。爲了驗證Focal Loss的有效性,文中還設計了一階段檢測網絡RetinaNet,並對比了和主流一階段模型以及二階段模型的性能和速度。結果也表明RetinaNet能夠以一階段模型的檢測速度達到二階段模型的檢測精度。