【深度学习】python实现NMS (非极大值抑制)

【深度学习】python实现NMS (非极大值抑制)

python实现NMS

首先需要明确一点,在多类别的目标检测任务中,NMS是发生在同一类别中的

NMS的流程:

# 以YOLO系列的目标检测算法为例, 网络输出tensor结构如下:
# [center_x, center_y, width, height, confidence, class1_score, class2_score, class3_score...]

# 假设一共有三个类别, 则输入tensor结构如下:
# boxes = np.array([center_x, center_y, width, height, confidence, class1_score, class2_score, class3_score],
#					[center_x, center_y, width, height, confidence, class1_score, class2_score, class3_score], 
# 					...
# 					[center_x, center_y, width, height, confidence, class1_score, class2_score, class3_score])

# 用于保留nms的结果box
result = []

# 对每个类别的输出分开处理
for each in range(类别数量):
	# 得到当前类别的所有输出
	the_boxes = boxes[np.where(boxes[:, 5:8].argsort()[:,-1] == each)[0].tolist(), :]

	center_x = the_boxes[:, 0]
	center_y = the_boxes[:, 1]
	width = the_boxes[:, 2]
	height = the_boxes[:, 3]
	confidence = the_boxes[:, 4]
	
	# 置信度从大到小排序
	index = confidecxe.argsort()[::-1]
	
	# 用于当前类别保留nms的结果box
	keep = []
	
	# 计算置信度最大box和其余所有box的IOU,大于阈值的则从index中剔除,保留当前置信度最大的box
	# 在index中剔除刚才保留的置信度最大的box,重复上述过程,直到index为空
	# 所有保留下来的box就是nms后的结果
	while index.size > 0:
		best = index[0]
		keep.append(the_boxes[best, :])
		
		# 函数get_iou用于计算置信度最大的box和其余所有box的IOU
		ious = get_iou(best, center_x, center_y, width, height)
		
		# thresh是nms的IOU阈值
		idx = np.where(ious <= thresh)[0]
		
		# 更新index,因为计算idx时,去除了原始index中最大的值,所以这里更新idx时要加1 
		index = index[idx + 1]

	result.append(keep)

python实现:

# -*- coding:utf-8 -*-

import numpy as np
import matplotlib.pyplot as plt


# 假设这个numpy是yolo网路的输出
boxes = np.array([[155, 155, 110, 110, 0.72, 0.2, 0.9, 0.7],
                  [335, 335, 170, 170, 0.8, 0.3, 0.4, 0.8],
                  [270, 275, 100, 110, 0.92, 0.6, 0.8, 0.2],
                  [165, 255, 90, 110, 0.72, 0.1, 0.5, 0.4],
                  [277, 285, 95, 90, 0.81, 0.3, 0.9, 0.4],
                  [225, 225, 150, 150, 0.7, 0.3, 0.4, 0.7],
                  [350, 250, 100, 100, 0.8, 0.9, 0.2, 0.6],
                  [267, 285, 95, 110, 0.9, 0.2, 0.7, 0.6]])


def get_iou(index, best, center_x, center_y, width, height):
    x1 = center_x - width / 2
    y1 = center_y - height / 2
    x2 = center_x + width / 2
    y2 = center_y + height / 2
    areas = (y2 - y1 + 1) * (x2 - x1 + 1)

    x11 = np.maximum(x1[best], x1[index[1:]])
    y11 = np.maximum(y1[best], y1[index[1:]])
    x22 = np.minimum(x2[best], x2[index[1:]])
    y22 = np.minimum(y2[best], y2[index[1:]])

    # 如果边框相交, x22 - x11 > 0, 如果边框不相交, w(h)设为0
    w = np.maximum(0, x22 - x11 + 1)
    h = np.maximum(0, y22 - y11 + 1)

    overlaps = w * h

    ious = overlaps / (areas[best] + areas[index[1:]] - overlaps)

    return ious


def nms(dets, thresh):
    """
    :param dets: numpy矩阵
    :param thresh: iou阈值
    :return:
    """

    result = []

    # 3类
    for each in range(3):
        the_boxes = dets[np.where(dets[:, 5:8].argsort()[:, -1] == each)[0].tolist(), :]

        center_x = the_boxes[:, 0]
        center_y = the_boxes[:, 1]
        width = the_boxes[:, 2]
        height = the_boxes[:, 3]
        confidence = the_boxes[:, 4]

        index = confidence.argsort()[::-1]

        keep = []

        while index.size > 0:
            best = index[0]
            keep.append(np.expand_dims(the_boxes[best, :], axis=0))

            ious = get_iou(index, best, center_x, center_y, width, height)

            idx = np.where(ious <= thresh)[0]

            index = index[idx + 1]

        result.append(np.concatenate(keep, axis=0))

    return np.concatenate(result, axis=0)


def plot_bbox(dets):
    center_x = dets[:, 0]
    center_y = dets[:, 1]
    width = dets[:, 2]
    height = dets[:, 3]

    class_id = dets[:, 5:8].argsort()[:, -1].tolist()

    color_list = ["lime", "magenta", "cyan"]
    for i, each in enumerate(class_id):
        x1 = int(center_x[i] - width[i] / 2)
        y1 = int(center_y[i] - height[i] / 2)
        x2 = int(center_x[i] + width[i] / 2)
        y2 = int(center_y[i] + height[i] / 2)

        c = color_list[each]

        plt.plot([x1, x2], [y1, y1], c)
        plt.plot([x1, x1], [y1, y2], c)
        plt.plot([x1, x2], [y2, y2], c)
        plt.plot([x2, x2], [y1, y2], c)


plt.figure(1)
ax1 = plt.subplot(1, 2, 1)
ax2 = plt.subplot(1, 2, 2)

plt.sca(ax1)

# nms之前的框
plot_bbox(boxes)

# nms之后的框
keep = nms(boxes, thresh=0.7)

plt.sca(ax2)
plot_bbox(keep)

plt.show()

在这里插入图片描述如上图,左边是未经过NMS处理的boxes,右边是经过NMS处理的boxes。(不同的颜色代表不同的类别)

结语

如果您有修改意见或问题,欢迎留言或者通过邮箱和我联系。
手打很辛苦,如果我的文章对您有帮助,转载请注明出处。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章