目標檢測 SSD: Single Shot MultiBox Detector - Hard Negative Mining

目標檢測 SSD: Single Shot MultiBox Detector - Hard Negative Mining

flyfish

首先介紹Negative

這還要從伊索寓言:狼來了(精簡版)說起

有一位牧童要照看鎮上的羊羣,但是他開始厭煩這份工作。爲了找點樂子,他大喊道:“狼來了!”其實根本一頭狼也沒有出現。村民們迅速跑來保護羊羣,但他們發現這個牧童是在開玩笑後非常生氣。

[這樣的情形重複出現了很多次。]

一天晚上,牧童看到真的有一頭狼靠近羊羣,他大聲喊道:“狼來了!”村民們不想再被他捉弄,都待在家裏不出來。這頭飢餓的狼對羊羣大開殺戒,美美飽餐了一頓。這下子,整個鎮子都揭不開鍋了。恐慌也隨之而來.

我們做出以下定義:

“狼來了”是正類別( positive class)。
“沒有狼”是負類別( negative class)。

我們可以使用一個 2x2 混淆矩陣來總結我們的“狼預測”模型,該矩陣描述了所有可能出現的結果(共四種):
在這裏插入圖片描述真正例 (TP,True Positive)
假正例 (FP,False Positive)
假負例 (FN,False Negative)
真負例 (TN,True Negative)
True Positive是指模型將正類別樣本正確地預測爲正類別。
True Negative是指模型將負類別樣本正確地預測爲負類別。
False Positive是指模型將負類別樣本錯誤地預測爲正類別。
False Negative是指模型將正類別樣本錯誤地預測爲負類別。

問題是什麼

我們手工設計算法生成了一堆prior boxs,因爲不是所有的prior boxs正好與ground truth正好重合,有完全不包含目標的,也有部分含有目標的.大部分的prior boxs是不包含目標的,不包含目標屬於背景類,就是negative class, IoU 超過一定閾值(例如0.5)才認爲是positive class,樣本的正負類極度的不平衡. 往極端的想,假設一共100個數,99個1,和1個0,模型每次判斷是1就行,因爲即使這樣模型也是能夠判斷99%的正確.

解決方案

我們把我們的prior boxs分個類分爲positive prior和 negative prior,negative prior太多就會造成negative class太多,這樣根據high confidence就可以減少negative class,
high confidence纔算negative class,low confidence不算negative class

擁有high confidence的negative加入數據集
擁有low confidence的negative忽略它,扔了

原來是 所有樣本=正樣本+負樣本
現在是 所有樣本=正樣本+擁有high confidence的負樣本,而且數量比例爲1:3(正樣本是1,負樣本是3)
這樣結果False Negative 就會少.相當於把非背景類識別成了背景類的這種情況就少了

代碼

在看代碼前先看一個函數的用法

import torch
a = torch.randn(4, 3)
print(a)
print("0:",torch.max(a, 0))
print("1:",torch.max(a, 1))


tensor([[ 0.2292, -1.0423,  0.4708],
        [-0.5750,  0.7802,  0.1230],
        [-0.0415, -0.3830, -0.3594],
        [ 0.7926, -0.6060, -0.7392]])

dim=0: 
torch.return_types.max(
values=tensor([0.7926, 0.7802, 0.4708]),
indices=tensor([3, 1, 0]))

dim=1: 
torch.return_types.max(
values=tensor([ 0.4708,  0.7802, -0.0415,  0.7926]),
indices=tensor([2, 1, 0, 0]))

4行3列,dim=0按列找最大的,dim=1是按行找最大的
#VGG版的priors有8732個,我們設計的有2278個

# predicted_locs:   torch.Size([32, 2278, 4])
# predicted_scores: torch.Size([32, 2278, 21])
# self.priors_cxcy   torch.Size([2278, 4])        
class MultiBoxLoss(nn.Module):
    def __init__(self, priors_cxcy, threshold=0.5, neg_pos_ratio=3, alpha=1.):
        super(MultiBoxLoss, self).__init__()
        self.priors_cxcy = priors_cxcy
        self.priors_xy = cxcy_to_xy(priors_cxcy)
        self.threshold = threshold
        self.neg_pos_ratio = neg_pos_ratio
        self.alpha = alpha
        self.smooth_l1 = nn.L1Loss()
        self.cross_entropy = nn.CrossEntropyLoss(reduce=False)
    def forward(self, predicted_locs, predicted_scores, boxes, labels):

        batch_size = predicted_locs.size(0)#這個是根據我們的配置輸出,用字母N替代,這裏假設是32
        n_priors = self.priors_cxcy.size(0)#2278 ,prior的個數
        n_classes = predicted_scores.size(2)#21 這個是根據我們數據集一共有多少類

        

        assert n_priors == predicted_locs.size(1) == predicted_scores.size(1)#必須全部是2278,如果不相同說明,設計的網絡有問題
        true_locs = torch.zeros((batch_size, n_priors, 4), dtype=torch.float).to(device)  #N,2278,4
        true_classes = torch.zeros((batch_size, n_priors), dtype=torch.long).to(device)  #N,2278


        for i in range(batch_size):
            n_objects = boxes[i].size(0)#n_objects可以是任意個數的目標
 
            overlap = find_jaccard_overlap(boxes[i],
                                           self.priors_xy)  #輸出的維度是(n_objects, 2278)
                                   
            #對於每個 prior,,查找重疊最大的目標,
            overlap_for_each_prior, object_for_each_prior = overlap.max(dim=0)  #2278
            # 輸出結果先是values,再是indices
            #overlap_for_each_prior和object_for_each_prior的shape都是 [2278]
                       
            #找出每個目標具有最大重疊的prior       
            _, prior_for_each_object = overlap.max(dim=1)  # 輸出兩個shape是n_objects
            #將每個對象分配給相應的最大重疊prior
            object_for_each_prior[prior_for_each_object] = torch.LongTensor(range(n_objects)).to(device)
            #爲確保這些prior合格,人爲地將它們的重疊部分設置爲大於0.5
            overlap_for_each_prior[prior_for_each_object] = 1.
            #每個prior的標籤
            label_for_each_prior = labels[i][object_for_each_prior]  
            #將與目標重疊小於閾值的prior 設置爲背景類
            label_for_each_prior[overlap_for_each_prior < self.threshold] = 0  
            
            true_classes[i] = label_for_each_prior
            
            true_locs[i] = cxcy_to_gcxgcy(xy_to_cxcy(boxes[i][object_for_each_prior]), self.priors_cxcy)  
        
        #確定prior是positive,目標不是背景類
        positive_priors = true_classes != 0  
 
        #一共兩個loss,confidence loss和localization loss
        #Localization loss 的計算僅僅是 positive  priors (非背景)
        loc_loss = self.smooth_l1(predicted_locs[positive_priors], true_locs[positive_priors])  
        
        #confidence loss
        #hard negative mining來了
        n_positives = positive_priors.sum(dim=1)  
        n_hard_negatives = self.neg_pos_ratio * n_positives  
        #找出所有的prior的loss
        conf_loss_all = self.cross_entropy(predicted_scores.view(-1, n_classes), true_classes.view(-1))  
        conf_loss_all = conf_loss_all.view(batch_size, n_priors)  
        #我們已經知道哪個prior是positive
        conf_loss_pos = conf_loss_all[positive_priors]  
        
        #我們要找出哪個prior是 hard negative

        # conf_loss_neg torch.Size([N, 2278])
        # hardness_ranks torch.Size([N, 2278])
        # n_hard_negatives torch.Size([N])

        #positive_priors([N, 2278])
        conf_loss_neg = conf_loss_all.clone()  #([N, 2278])
        conf_loss_neg[positive_priors] = 0.  
        print(conf_loss_neg[positive_priors],conf_loss_neg)

     
        conf_loss_neg, _ = conf_loss_neg.sort(dim=1, descending=True)  #降序
        hardness_ranks = torch.LongTensor(range(n_priors)).unsqueeze(0).expand_as(conf_loss_neg).to(device)   #hardness_ranks torch.Size([N, 2278])
        hard_negatives = hardness_ranks < n_hard_negatives.unsqueeze(1) 

        # hardness_ranks torch.Size([N, 2278])
        # n_hard_negatives torch.Size([N])

        conf_loss_hard_neg = conf_loss_neg[hard_negatives]  
        #這裏按照論文中敘述,僅對positive priors進行平均,儘管同時計算 positive prior和 hard-negative prior。
        conf_loss = (conf_loss_hard_neg.sum() + conf_loss_pos.sum()) / n_positives.sum().float()  
        
        # self.alpha是權重,論文裏是1
        return conf_loss + self.alpha * loc_loss

參考

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章