Understanding the NMS Algorithm
Non-Maximum Suppression (NMS)
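After the prediction network outputs bbox positions, there will always be many candidates. Each bbox carries position information and a confidence (probability), so NMS is used to discard the redundant bboxes.
For example, suppose a face detector returns 8 detection boxes, and among them there are clearly two people who each have two boxes. Those extra boxes are redundant, and NMS is used to remove them.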
Code:
// Bounding box (bbox) data structure for a face detection result
typedef struct FaceRect{
    float x1;
    float y1;
    float x2;
    float y2;
    float score; /**< Larger score should mean higher confidence. */
} FaceRect;
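Sorting:
// Compare by score: returns true when a should be ordered before b.
// FaceInfo is not shown here; the code only relies on its bbox member being a FaceRect.
bool CompareBBox(const FaceInfo & a, const FaceInfo & b) {
    return a.bbox.score > b.bbox.score;
}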
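To see the sort step in isolation, here is a minimal sketch. The FaceInfo layout below is an assumption (the real struct may carry more fields; all we know is that it has a bbox member), and SortedScores is a hypothetical helper added only for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Same fields as the FaceRect struct above.
struct FaceRect {
    float x1, y1, x2, y2;
    float score;  // larger score means higher confidence
};

// Assumption: only the bbox member matters for this example.
struct FaceInfo {
    FaceRect bbox;
};

// Returns true when a should come before b, i.e. a has the higher score.
bool CompareBBox(const FaceInfo& a, const FaceInfo& b) {
    return a.bbox.score > b.bbox.score;
}

// Hypothetical helper: sorts a copy of the input and returns the scores
// in the resulting (descending) order.
std::vector<float> SortedScores(std::vector<FaceInfo> faces) {
    std::sort(faces.begin(), faces.end(), CompareBBox);
    assert(std::is_sorted(faces.begin(), faces.end(), CompareBBox));
    std::vector<float> scores;
    for (const FaceInfo& f : faces) scores.push_back(f.bbox.score);
    return scores;
}
```

Because the comparator uses >, std::sort puts the most confident box first, which is what the NMS loop relies on.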
NMS code:
#include <algorithm>  // std::sort, std::max, std::min
#include <cstdint>    // int32_t
#include <vector>

// methodType selects the overlap measure:
//   'u': intersection / union (IoU)
//   'm': intersection / area of the smaller box
std::vector<FaceInfo> NonMaximumSuppression(std::vector<FaceInfo>& bboxes,
                                            float thresh, char methodType) {
    std::vector<FaceInfo> bboxes_nms;
    std::sort(bboxes.begin(), bboxes.end(), CompareBBox);  // sort by score, descending
    int32_t select_idx = 0;
    int32_t num_bbox = static_cast<int32_t>(bboxes.size());
    std::vector<int32_t> mask_merged(num_bbox, 0);  // 1 = already kept or suppressed
    bool all_merged = false;
    while (!all_merged) {
        // Advance to the next box that has not been handled yet.
        while (select_idx < num_bbox && mask_merged[select_idx] == 1)
            select_idx++;
        if (select_idx == num_bbox) {
            all_merged = true;
            continue;
        }
        // Keep the highest-scoring remaining box.
        bboxes_nms.push_back(bboxes[select_idx]);
        mask_merged[select_idx] = 1;
        FaceRect select_bbox = bboxes[select_idx].bbox;
        float area1 = static_cast<float>((select_bbox.x2 - select_bbox.x1 + 1) * (select_bbox.y2 - select_bbox.y1 + 1));
        float x1 = static_cast<float>(select_bbox.x1);
        float y1 = static_cast<float>(select_bbox.y1);
        float x2 = static_cast<float>(select_bbox.x2);
        float y2 = static_cast<float>(select_bbox.y2);
        select_idx++;
        // Compare the kept box against every lower-scoring box still in play.
        for (int32_t i = select_idx; i < num_bbox; i++) {
            if (mask_merged[i] == 1)
                continue;
            FaceRect& bbox_i = bboxes[i].bbox;
            float x = std::max<float>(x1, static_cast<float>(bbox_i.x1));
            float y = std::max<float>(y1, static_cast<float>(bbox_i.y1));
            float w = std::min<float>(x2, static_cast<float>(bbox_i.x2)) - x + 1;
            float h = std::min<float>(y2, static_cast<float>(bbox_i.y2)) - y + 1;
            if (w <= 0 || h <= 0)
                continue;  // no overlap: leave this box for a later round
            float area2 = static_cast<float>((bbox_i.x2 - bbox_i.x1 + 1) * (bbox_i.y2 - bbox_i.y1 + 1));
            float area_intersect = w * h;
            switch (methodType) {
                case 'u':
                    if (static_cast<float>(area_intersect) / (area1 + area2 - area_intersect) > thresh)
                        mask_merged[i] = 1;
                    break;
                case 'm':
                    if (static_cast<float>(area_intersect) / std::min(area1, area2) > thresh)
                        mask_merged[i] = 1;
                    break;
                default:
                    break;
            }
        }
    }
    return bboxes_nms;
}
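The two methodType branches differ only in the denominator of the overlap ratio. The sketch below isolates that ratio so the difference is easy to see. Box and OverlapRatio are hypothetical names introduced for this illustration, and the sketch keeps the same "+1" inclusive-pixel convention as the code above.

```cpp
#include <algorithm>
#include <cassert>

struct Box { float x1, y1, x2, y2; };  // hypothetical stand-in for FaceRect

// Area with the inclusive-pixel "+1" convention.
float Area(const Box& b) {
    assert(b.x2 >= b.x1 && b.y2 >= b.y1);
    return (b.x2 - b.x1 + 1) * (b.y2 - b.y1 + 1);
}

// 'u': intersection / union; 'm': intersection / smaller area.
float OverlapRatio(const Box& a, const Box& b, char methodType) {
    float w = std::min(a.x2, b.x2) - std::max(a.x1, b.x1) + 1;
    float h = std::min(a.y2, b.y2) - std::max(a.y1, b.y1) + 1;
    if (w <= 0 || h <= 0) return 0.0f;  // boxes do not overlap
    float inter = w * h;
    if (methodType == 'u') return inter / (Area(a) + Area(b) - inter);
    if (methodType == 'm') return inter / std::min(Area(a), Area(b));
    return 0.0f;  // unknown method: report no overlap, so nothing is suppressed
}
```

For a 100x100 box {0, 0, 99, 99} that fully contains a 20x20 box {10, 10, 29, 29}, the 'u' ratio is 400 / 10000 = 0.04 while the 'm' ratio is 400 / 400 = 1.0. So with a threshold of 0.5 the small inner box survives under 'u' but is suppressed under 'm'.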
General idea
First, sort the 8 bboxes in descending order of bbox.score, from highest to lowest.
std::sort(bboxes.begin(), bboxes.end(), CompareBBox);  // sort by score, descending
The box with the highest score is always kept.
bboxes_nms.push_back(bboxes[select_idx]);  // select_idx is initially 0
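Next, loop (number of faces - 1) times, computing the IoU between each later face and the first face. If the IoU is greater than the threshold, that face is discarded. If a later face has an IoU of 0 with the first face (no overlap at all), it is skipped and left for a later round.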
if (static_cast<float>(area_intersect) / (area1 + area2 - area_intersect) > thresh)
    mask_merged[i] = 1;  // mask_merged is the marker: 1 means the IoU exceeded the threshold and the box is excluded
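The same check is then repeated with the next highest-scoring box that has not been marked, and so on, until all the redundant boxes have been eliminated.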
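Putting the steps together, here is a compact sketch of the same greedy procedure run on eight boxes. This is a simplified IoU-only variant, not the code above: Det, IoU, GreedyNms, and ExampleDetections are hypothetical names, and the coordinates and scores are made up for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

struct Det { float x1, y1, x2, y2, score; };  // hypothetical detection record

// Intersection over union with the inclusive-pixel "+1" convention.
float IoU(const Det& a, const Det& b) {
    float w = std::min(a.x2, b.x2) - std::max(a.x1, b.x1) + 1;
    float h = std::min(a.y2, b.y2) - std::max(a.y1, b.y1) + 1;
    if (w <= 0 || h <= 0) return 0.0f;
    float inter = w * h;
    float areaA = (a.x2 - a.x1 + 1) * (a.y2 - a.y1 + 1);
    float areaB = (b.x2 - b.x1 + 1) * (b.y2 - b.y1 + 1);
    return inter / (areaA + areaB - inter);
}

// Sort by score, keep the best remaining box, suppress overlapping ones, repeat.
std::vector<Det> GreedyNms(std::vector<Det> dets, float thresh) {
    assert(thresh >= 0.0f && thresh <= 1.0f);
    std::sort(dets.begin(), dets.end(),
              [](const Det& a, const Det& b) { return a.score > b.score; });
    std::vector<bool> suppressed(dets.size(), false);
    std::vector<Det> kept;
    for (std::size_t i = 0; i < dets.size(); ++i) {
        if (suppressed[i]) continue;
        kept.push_back(dets[i]);
        for (std::size_t j = i + 1; j < dets.size(); ++j) {
            if (!suppressed[j] && IoU(dets[i], dets[j]) > thresh) suppressed[j] = true;
        }
    }
    return kept;
}

// Eight made-up detections clustered around two faces.
std::vector<Det> ExampleDetections() {
    return {
        {10, 10, 60, 60, 0.95f},    {12, 11, 62, 61, 0.90f},
        {9, 12, 58, 63, 0.80f},     {14, 8, 64, 59, 0.70f},
        {200, 50, 260, 110, 0.92f}, {203, 52, 262, 113, 0.85f},
        {198, 48, 257, 108, 0.75f}, {205, 55, 265, 114, 0.60f},
    };
}
```

Calling GreedyNms(ExampleDetections(), 0.5f) leaves two boxes, the ones scored 0.95 and 0.92, because every other box overlaps one of those two with an IoU above 0.5.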
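However, NMS can sometimes hurt recall in object detection. For example, when two real objects overlap heavily, the lower-scoring box is suppressed even though it is a correct detection.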
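A small numeric sketch of that situation, assuming two boxes that are both correct detections of different, partly occluding faces. Face, IoU, and SurvivorsOfTwo are hypothetical names, and the coordinates are made up.

```cpp
#include <algorithm>
#include <cassert>

struct Face { float x1, y1, x2, y2, score; };  // hypothetical detection record

// Intersection over union with the inclusive-pixel "+1" convention.
float IoU(const Face& a, const Face& b) {
    float w = std::min(a.x2, b.x2) - std::max(a.x1, b.x1) + 1;
    float h = std::min(a.y2, b.y2) - std::max(a.y1, b.y1) + 1;
    if (w <= 0 || h <= 0) return 0.0f;
    float inter = w * h;
    float areaA = (a.x2 - a.x1 + 1) * (a.y2 - a.y1 + 1);
    float areaB = (b.x2 - b.x1 + 1) * (b.y2 - b.y1 + 1);
    return inter / (areaA + areaB - inter);
}

// With only two boxes, greedy NMS drops the lower-scoring one exactly when
// their IoU exceeds the threshold, so this returns 1 or 2.
int SurvivorsOfTwo(const Face& a, const Face& b, float thresh) {
    assert(thresh >= 0.0f && thresh <= 1.0f);
    return IoU(a, b) > thresh ? 1 : 2;
}
```

Take a front face {100, 100, 199, 199} with score 0.9 and a face behind it {130, 110, 229, 209} with score 0.8. Their intersection is 70 x 90 = 6300 and their union is 13700, so the IoU is about 0.46. At a threshold of 0.3 only one box survives and the second face is lost; at 0.5 both survive. A looser threshold protects recall here but lets more duplicates through elsewhere.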
One paper proposes an improvement to NMS aimed at exactly this problem, called Soft-NMS. I will summarize Soft-NMS in the next post.