NMS(非極大值抑制,Non-Maximum Suppression)主要用在DetectionOutput層中,用於對網絡預測得到的分數(scores)和邊界框(boxes)進行後處理,去除與已保留框重疊度過高的冗餘框。
在此層的Forward中調用,如下:
// Call site in DetectionOutput's Forward: the indices of the boxes kept by NMS
// are written to indices[c] (c is presumably the class index of the enclosing
// loop in Forward; confirm there).
ApplyNMSFast(bboxes, scores, confidence_threshold_, nms_threshold_, eta_,
top_k_, &(indices[c]));
此ApplyNMSFast方法是在caffe/util/bbox_util.cpp中實現,如下:
// Greedy non-maximum suppression over `bboxes`.
//
// Candidates whose score is above `score_threshold` are visited in descending
// score order (limited to the best `top_k` when top_k > -1, see
// GetMaxScoreIndex). A candidate is kept only if its Jaccard overlap (IoU)
// with every previously kept box is <= the current threshold.
//
// bboxes, scores:  parallel arrays; they must have the same size.
// score_threshold: candidates with score <= this value are ignored.
// nms_threshold:   initial IoU threshold.
// eta:             if < 1, the threshold is multiplied by eta after each kept
//                  box, as long as the threshold is > 0.5 before the update.
// top_k:           maximum number of candidates considered (-1 = all).
// indices:         output; cleared, then filled with the indices (into
//                  bboxes) of the kept boxes, in descending score order.
void ApplyNMSFast(const vector<NormalizedBBox>& bboxes,
      const vector<float>& scores, const float score_threshold,
      const float nms_threshold, const float eta, const int top_k,
      vector<int>* indices) {
  // Sanity check.
  CHECK_EQ(bboxes.size(), scores.size())
      << "bboxes and scores have different size.";

  // Candidate (score, index) pairs above score_threshold, sorted by
  // descending score and truncated to top_k when requested.
  vector<pair<float, int> > score_index_vec;
  GetMaxScoreIndex(scores, score_threshold, top_k, &score_index_vec);

  // Do nms.
  float adaptive_threshold = nms_threshold;
  indices->clear();
  // Walk the candidates by position instead of repeatedly erasing the front
  // of the vector: erase(begin()) shifts every remaining element, which made
  // the loop quadratic even when very few boxes are kept.
  for (size_t i = 0; i < score_index_vec.size(); ++i) {
    const int idx = score_index_vec[i].second;
    bool keep = true;
    for (size_t k = 0; k < indices->size(); ++k) {
      const int kept_idx = (*indices)[k];
      const float overlap = JaccardOverlap(bboxes[idx], bboxes[kept_idx]);
      // Keep only while overlap <= threshold. This form is deliberate: a NaN
      // overlap compares false and therefore suppresses the candidate.
      keep = overlap <= adaptive_threshold;
      if (!keep) {
        break;
      }
    }
    if (keep) {
      indices->push_back(idx);
      // Adaptive NMS: tighten the threshold after each kept box.
      if (eta < 1 && adaptive_threshold > 0.5) {
        adaptive_threshold *= eta;
      }
    }
  }
}
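其中計算NMS用到的GetMaxScoreIndex和JaccardOverlap方法均在caffe/util/bbox_util.cpp中實現,如下。注意:ApplyNMSFast中實際調用的是接受NormalizedBBox參數的JaccardOverlap重載,下面列出的則是接受Dtype*指針(按xmin, ymin, xmax, ymax排列)的模板版本,二者計算的都是兩個box的交併比(IoU):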
void GetMaxScoreIndex(const vector<float>& scores, const float threshold,
const int top_k, vector<pair<float, int> >* score_index_vec) {
// Generate index score pairs.
for (int i = 0; i < scores.size(); ++i) {
if (scores[i] > threshold) {
score_index_vec->push_back(std::make_pair(scores[i], i));
}
}
// Sort the score pair according to the scores in descending order
std::stable_sort(score_index_vec->begin(), score_index_vec->end(),
SortScorePairDescend<int>);
// Keep top_k scores if needed.
if (top_k > -1 && top_k < score_index_vec->size()) {
score_index_vec->resize(top_k);
}
}
// Jaccard overlap (intersection over union) of two axis-aligned boxes.
//
// Each box is read as bbox[0..3] = (xmin, ymin, xmax, ymax). Box areas come
// from BBoxSize (defined elsewhere in this file).
//
// Returns 0 when the boxes share no region of positive area. That covers
// disjoint boxes, boxes that merely touch, degenerate (zero-width or
// zero-height) boxes, and inverted boxes (max < min). For finite inputs, the
// result is therefore never negative and never 0/0 = NaN.
template <typename Dtype>
Dtype JaccardOverlap(const Dtype* bbox1, const Dtype* bbox2) {
  const Dtype inter_xmin = std::max(bbox1[0], bbox2[0]);
  const Dtype inter_ymin = std::max(bbox1[1], bbox2[1]);
  const Dtype inter_xmax = std::min(bbox1[2], bbox2[2]);
  const Dtype inter_ymax = std::min(bbox1[3], bbox2[3]);
  const Dtype inter_width = inter_xmax - inter_xmin;
  const Dtype inter_height = inter_ymax - inter_ymin;
  // Empty intersection. This subsumes the strict disjointness test and also
  // handles touching degenerate boxes, which previously produced 0/0 = NaN,
  // and inverted boxes, which previously produced a negative IoU.
  if (inter_width <= Dtype(0.) || inter_height <= Dtype(0.)) {
    return Dtype(0.);
  }
  const Dtype inter_size = inter_width * inter_height;
  const Dtype union_size = BBoxSize(bbox1) + BBoxSize(bbox2) - inter_size;
  // Defensive: never divide by a non-positive union, whatever BBoxSize reports.
  if (union_size <= Dtype(0.)) {
    return Dtype(0.);
  }
  return inter_size / union_size;
}