NMS算法的理解

NMS算法的理解

NonMaximumSuppression 非極大值抑制

當預測網絡預測出bbox的位置之後,一定會產生很多種可能。每一個bbox包括位置信息和置信度(概率),這個時候就需要根據nms的來排除掉一些冗餘的bbox。

例如,人臉檢測算法得到了8個人臉檢測框,這8個檢測框中明顯是由兩個人同時有兩個框的,這樣就產生了冗餘,需要利用nms將這些多餘的框去掉。
這裏寫圖片描述

代碼:

//人臉檢測結果數據結構bbox
typedef struct  FaceRect{
  float x1;
  float y1;
  float x2;
  float y2;
  float score; /**< Larger score should mean higher confidence. */
} FaceRect;

排序:

// compare score
bool CompareBBox(const FaceInfo & a, const FaceInfo & b) {
    return a.bbox.score > b.bbox.score;
}

nms代碼:

std::vector<FaceInfo> NonMaximumSuppression(std::vector<FaceInfo>& bboxes,
                                                   float thresh,char methodType){
    std::vector<FaceInfo> bboxes_nms;
    std::sort(bboxes.begin(), bboxes.end(), CompareBBox);//按照score降序排列

    int32_t select_idx = 0;
    int32_t num_bbox = static_cast<int32_t>(bboxes.size());
    std::vector<int32_t> mask_merged(num_bbox, 0);
    bool all_merged = false;

    while (!all_merged) {
        while (select_idx < num_bbox && mask_merged[select_idx] == 1)
            select_idx++;
        if (select_idx == num_bbox) {
            all_merged = true;
            continue;NM
        }

        bboxes_nms.push_back(bboxes[select_idx]);
        mask_merged[select_idx] = 1;

        FaceRect select_bbox = bboxes[select_idx].bbox;
        float area1 = static_cast<float>((select_bbox.x2-select_bbox.x1+1) * (select_bbox.y2-select_bbox.y1+1));
        float x1 = static_cast<float>(select_bbox.x1);
        float y1 = static_cast<float>(select_bbox.y1);
        float x2 = static_cast<float>(select_bbox.x2);
        float y2 = static_cast<float>(select_bbox.y2);

        select_idx++;
        for (int32_t i = select_idx; i < num_bbox; i++) {
            if (mask_merged[i] == 1)
                continue;

            FaceRect& bbox_i = bboxes[i].bbox;
            float x = std::max<float>(x1, static_cast<float>(bbox_i.x1));
            float y = std::max<float>(y1, static_cast<float>(bbox_i.y1));
            float w = std::min<float>(x2, static_cast<float>(bbox_i.x2)) - x + 1;
            float h = std::min<float>(y2, static_cast<float>(bbox_i.y2)) - y + 1;
            if (w <= 0 || h <= 0)
                continue;

            float area2 = static_cast<float>((bbox_i.x2-bbox_i.x1+1) * (bbox_i.y2-bbox_i.y1+1));
            float area_intersect = w * h;

            switch (methodType) {
            case 'u':
                if (static_cast<float>(area_intersect) / (area1 + area2 - area_intersect) > thresh)
                    mask_merged[i] = 1;
                break;
            case 'm':
                if (static_cast<float>(area_intersect) / std::min(area1 , area2) > thresh)
                    mask_merged[i] = 1;
                break;
            default:
                break;
            }
        }
    }
    return bboxes_nms;
}

大體思路

首先對8個bbox進行降序排列,排序的依據就是bbox.score的值,從大到小依次排好。

std::sort(bboxes.begin(), bboxes.end(), CompareBBox);//按照score降序排列

分數最高的肯定是要保留下來的

bboxes_nms.push_back(bboxes[select_idx]);//最初select_idx = 0

然後進行一個 (人臉數 -1)次的循環,依次判斷後續的人臉的和第一個人臉的IOU,如果大於閾值,那麼就將這個人臉pass掉,如果後面的人臉和第一個人臉的IOU爲0那麼就跳過,先不處理。

if (static_cast<float>(area_intersect) / (area1 + area2 - area_intersect) > thresh)
  mask_merged[i] = 1;//這裏利用mesk_merged做標記,標記爲1的證明大於閾值,需要排除

依次循環判斷,最後將冗餘的框排除。

但是nms在目標檢測的有些時候會影響到一定的召回率,例如下面這種情況

這裏寫圖片描述

有一篇論文專門針對nms進行了優化,稱之爲soft-nms。下篇文章再總結下soft-nms

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章