'''
# INPUT:所有預測出的bounding box (bbx)信息(座標和置信度confidence), IOU閾值(大於該閾值的bbx將被移除)
for object in all objects:
(1) 獲取當前目標類別下所有bbx的信息
(2) 將bbx按照confidence從高到低排序,並記錄當前confidence最大的bbx
(3) 計算最大confidence對應的bbx與剩下所有的bbx的IOU,移除所有大於IOU閾值的bbx
(4) 對剩下的bbx,循環執行(2)和(3)直到所有的bbx均滿足要求(即不能再移除bbx)
'''
import numpy as np
def non_max_suppress(predicts_dict, threshhold=0.2):
'''
:param predicts_dict: {'分類1':[[Xmin, Ymin, Xmax, Ymax, Score], [...]], '分類2':[[...]]}
:param threshhold: suprress threshhold
:return:
'''
for object_name, bbox in predicts_dict.items():
# list to array
bbox_array = np.array(bbox, dtype=np.float)
# get coordinates
Xmin, Ymin, Xmax, Ymax, scores = bbox_array[:, 0], bbox_array[:, 1], bbox_array[:, 2], bbox_array[:, 3], bbox_array[:, 4],
# 獲得每個bbox的面積,用於計算並集
bbox_area = (Xmax - Xmin) * (Ymax - Ymin)
# 按score降序排序的bbox索引
order = scores.argsort()[::-1]
# 最終保留的bbox索引
keep = []
while order.size > 0:
# 第一個爲該類別最大置信度的索引,保留
i = order[0]
keep.append(i)
# 計算與其他bbox的IOU
# 計算左上角和右下角座標
inter_Xmin = np.maximum(Xmin[i], Xmin[order[1:]])
inter_Ymin = np.maximum(Ymin[i], Ymin[order[1:]])
inter_Xmax = np.minimum(Xmax[i], Xmax[order[1:]])
inter_Ymax = np.minimum(Ymax[i], Ymax[order[1:]])
# 計算交集
inter_area = np.maximum(0, inter_Xmax-inter_Xmin) * np.maximum(0, inter_Ymax - inter_Ymin)
# 計算IOU
IOU = inter_area / (bbox_area[i] + bbox_area[order[1:]] - inter_area + 1e-6)
# 獲取保留下來的索引(因爲沒有計算與自身的IOU,所以索引相差1,需要加上)
indexs = np.where(IOU <= threshhold)[0] + 1
order = order[indexs]
bbox = bbox_array[keep]
predicts_dict[object_name] = bbox.tolist()
return predicts_dict
predicts_dict = non_max_suppress({'蘋果':[[0, 0, 1, 1, 0.6], [2, 2, 4, 4, 0.8], [3, 3, 4, 4, 0.5]]})
print(predicts_dict)