【深度學習】python實現NMS (非極大值抑制)

【深度學習】python實現NMS (非極大值抑制)

python實現NMS

首先需要明確一點,在多類別的目標檢測任務中,NMS是發生在同一類別中的

NMS的流程:

# 以YOLO系列的目標檢測算法爲例, 網絡輸出tensor結構如下:
# [center_x, center_y, width, height, confidence, class1_score, class2_score, class3_score...]

# 假設一共有三個類別, 則輸入tensor結構如下:
# boxes = np.array([center_x, center_y, width, height, confidence, class1_score, class2_score, class3_score],
#					[center_x, center_y, width, height, confidence, class1_score, class2_score, class3_score], 
# 					...
# 					[center_x, center_y, width, height, confidence, class1_score, class2_score, class3_score])

# 用於保留nms的結果box
result = []

# 對每個類別的輸出分開處理
for each in range(類別數量):
	# 得到當前類別的所有輸出
	the_boxes = boxes[np.where(boxes[:, 5:8].argsort()[:,-1] == each)[0].tolist(), :]

	center_x = the_boxes[:, 0]
	center_y = the_boxes[:, 1]
	width = the_boxes[:, 2]
	height = the_boxes[:, 3]
	confidence = the_boxes[:, 4]
	
	# 置信度從大到小排序
	index = confidecxe.argsort()[::-1]
	
	# 用於當前類別保留nms的結果box
	keep = []
	
	# 計算置信度最大box和其餘所有box的IOU,大於閾值的則從index中剔除,保留當前置信度最大的box
	# 在index中剔除剛纔保留的置信度最大的box,重複上述過程,直到index爲空
	# 所有保留下來的box就是nms後的結果
	while index.size > 0:
		best = index[0]
		keep.append(the_boxes[best, :])
		
		# 函數get_iou用於計算置信度最大的box和其餘所有box的IOU
		ious = get_iou(best, center_x, center_y, width, height)
		
		# thresh是nms的IOU閾值
		idx = np.where(ious <= thresh)[0]
		
		# 更新index,因爲計算idx時,去除了原始index中最大的值,所以這裏更新idx時要加1 
		index = index[idx + 1]

	result.append(keep)

python實現:

# -*- coding:utf-8 -*-

import numpy as np
import matplotlib.pyplot as plt


# 假設這個numpy是yolo網路的輸出
boxes = np.array([[155, 155, 110, 110, 0.72, 0.2, 0.9, 0.7],
                  [335, 335, 170, 170, 0.8, 0.3, 0.4, 0.8],
                  [270, 275, 100, 110, 0.92, 0.6, 0.8, 0.2],
                  [165, 255, 90, 110, 0.72, 0.1, 0.5, 0.4],
                  [277, 285, 95, 90, 0.81, 0.3, 0.9, 0.4],
                  [225, 225, 150, 150, 0.7, 0.3, 0.4, 0.7],
                  [350, 250, 100, 100, 0.8, 0.9, 0.2, 0.6],
                  [267, 285, 95, 110, 0.9, 0.2, 0.7, 0.6]])


def get_iou(index, best, center_x, center_y, width, height):
    x1 = center_x - width / 2
    y1 = center_y - height / 2
    x2 = center_x + width / 2
    y2 = center_y + height / 2
    areas = (y2 - y1 + 1) * (x2 - x1 + 1)

    x11 = np.maximum(x1[best], x1[index[1:]])
    y11 = np.maximum(y1[best], y1[index[1:]])
    x22 = np.minimum(x2[best], x2[index[1:]])
    y22 = np.minimum(y2[best], y2[index[1:]])

    # 如果邊框相交, x22 - x11 > 0, 如果邊框不相交, w(h)設爲0
    w = np.maximum(0, x22 - x11 + 1)
    h = np.maximum(0, y22 - y11 + 1)

    overlaps = w * h

    ious = overlaps / (areas[best] + areas[index[1:]] - overlaps)

    return ious


def nms(dets, thresh):
    """
    :param dets: numpy矩陣
    :param thresh: iou閾值
    :return:
    """

    result = []

    # 3類
    for each in range(3):
        the_boxes = dets[np.where(dets[:, 5:8].argsort()[:, -1] == each)[0].tolist(), :]

        center_x = the_boxes[:, 0]
        center_y = the_boxes[:, 1]
        width = the_boxes[:, 2]
        height = the_boxes[:, 3]
        confidence = the_boxes[:, 4]

        index = confidence.argsort()[::-1]

        keep = []

        while index.size > 0:
            best = index[0]
            keep.append(np.expand_dims(the_boxes[best, :], axis=0))

            ious = get_iou(index, best, center_x, center_y, width, height)

            idx = np.where(ious <= thresh)[0]

            index = index[idx + 1]

        result.append(np.concatenate(keep, axis=0))

    return np.concatenate(result, axis=0)


def plot_bbox(dets):
    center_x = dets[:, 0]
    center_y = dets[:, 1]
    width = dets[:, 2]
    height = dets[:, 3]

    class_id = dets[:, 5:8].argsort()[:, -1].tolist()

    color_list = ["lime", "magenta", "cyan"]
    for i, each in enumerate(class_id):
        x1 = int(center_x[i] - width[i] / 2)
        y1 = int(center_y[i] - height[i] / 2)
        x2 = int(center_x[i] + width[i] / 2)
        y2 = int(center_y[i] + height[i] / 2)

        c = color_list[each]

        plt.plot([x1, x2], [y1, y1], c)
        plt.plot([x1, x1], [y1, y2], c)
        plt.plot([x1, x2], [y2, y2], c)
        plt.plot([x2, x2], [y1, y2], c)


plt.figure(1)
ax1 = plt.subplot(1, 2, 1)
ax2 = plt.subplot(1, 2, 2)

plt.sca(ax1)

# nms之前的框
plot_bbox(boxes)

# nms之後的框
keep = nms(boxes, thresh=0.7)

plt.sca(ax2)
plot_bbox(keep)

plt.show()

在這裏插入圖片描述如上圖,左邊是未經過NMS處理的boxes,右邊是經過NMS處理的boxes。(不同的顏色代表不同的類別)

結語

如果您有修改意見或問題,歡迎留言或者通過郵箱和我聯繫。
手打很辛苦,如果我的文章對您有幫助,轉載請註明出處。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章