【庖丁解牛】Implementing RetinaNet from Scratch (8): Improving the RetinaNet Regression Loss with GIoU, DIoU, and CIoU

All code has been uploaded to my GitHub repository: https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training
If you find it useful, please give it a star!
All of the code below has been tested under PyTorch 1.4 and confirmed to work correctly.

RetinaNet uses smooth L1 loss as its regression loss, but IoU itself can also serve as a loss. IoU has a drawback, however: for two non-overlapping boxes the IoU is always 0, so it cannot reflect how far apart the boxes are. To address this, a series of improvements on IoU loss were proposed: GIoU loss, DIoU loss, and CIoU loss.
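As a quick illustration of that drawback (a minimal standalone sketch, not code from the repository), the snippet below computes the IoU of paired boxes in [x_min, y_min, x_max, y_max] format; every pair of disjoint boxes yields exactly 0, no matter how far apart they are:

import torch

def box_iou(boxes_a, boxes_b):
    """IoU for paired boxes of shape [N,4] in [x_min,y_min,x_max,y_max] format."""
    tl = torch.max(boxes_a[:, :2], boxes_b[:, :2])   # overlap top-left corner
    br = torch.min(boxes_a[:, 2:], boxes_b[:, 2:])   # overlap bottom-right corner
    wh = torch.clamp(br - tl, min=0)                 # overlap width/height, 0 if disjoint
    inter = wh[:, 0] * wh[:, 1]
    area_a = (boxes_a[:, 2] - boxes_a[:, 0]) * (boxes_a[:, 3] - boxes_a[:, 1])
    area_b = (boxes_b[:, 2] - boxes_b[:, 0]) * (boxes_b[:, 3] - boxes_b[:, 1])
    return inter / torch.clamp(area_a + area_b - inter, min=1e-4)

a = torch.tensor([[0., 0., 10., 10.], [0., 0., 10., 10.]])
b = torch.tensor([[20., 20., 30., 30.], [100., 100., 110., 110.]])
print(box_iou(a, b))  # tensor([0., 0.]): both disjoint pairs give 0, near or far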

GIoU loss

GIoU loss comes from this paper: https://arxiv.org/pdf/1902.09630.pdf . GIoU extends IoU with an extra term: the fraction of the smallest enclosing box that is not covered by either of the two boxes. This term measures the distance between two non-overlapping boxes, so even when the boxes do not intersect it still provides a direction in which to move the predicted box. IoU lies in [0, 1], while GIoU lies in [-1, 1]. The GIoU loss is defined as 1 - GIoU, so its range is [0, 2]. Note that since I replace smooth L1 loss with GIoU loss directly, all of the samples involved are positive samples (whose IoU with the ground truth is never 0), so the completely-disjoint case does not actually occur during optimization.
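In other words, GIoU = IoU - (C - U) / C, where U is the union area and C is the area of the smallest enclosing box. A minimal standalone sketch (independent of the implementation below) showing GIoU going negative for disjoint boxes:

import torch

def box_giou(boxes_a, boxes_b):
    """GIoU for paired boxes of shape [N,4] in [x_min,y_min,x_max,y_max] format."""
    tl = torch.max(boxes_a[:, :2], boxes_b[:, :2])
    br = torch.min(boxes_a[:, 2:], boxes_b[:, 2:])
    inter = torch.clamp(br - tl, min=0).prod(dim=1)
    area_a = (boxes_a[:, 2:] - boxes_a[:, :2]).prod(dim=1)
    area_b = (boxes_b[:, 2:] - boxes_b[:, :2]).prod(dim=1)
    union = torch.clamp(area_a + area_b - inter, min=1e-4)
    iou = inter / union
    # smallest enclosing box of each pair
    enclose_tl = torch.min(boxes_a[:, :2], boxes_b[:, :2])
    enclose_br = torch.max(boxes_a[:, 2:], boxes_b[:, 2:])
    enclose = torch.clamp((enclose_br - enclose_tl).prod(dim=1), min=1e-4)
    return iou - (enclose - union) / enclose

a = torch.tensor([[0., 0., 10., 10.]])
b = torch.tensor([[20., 0., 30., 10.]])
print(box_giou(a, b))       # about -0.33: disjoint, but GIoU still ranks the distance
print(1. - box_giou(a, b))  # the corresponding GIoU loss, in [0, 2]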
The GIoU loss is implemented as follows:

    def compute_one_image_giou_loss(self, per_image_reg_heads,
                                    per_image_anchors,
                                    per_image_anchors_annotations):
        """
        compute one image giou loss(reg loss)
        per_image_reg_heads:[anchor_num,4]
        per_image_anchors:[anchor_num,4]
        per_image_anchors_annotations:[anchor_num,5]
        """
        # Only positive anchors (gt class > 0) contribute to the giou loss
        device = per_image_reg_heads.device
        per_image_reg_heads = per_image_reg_heads[
            per_image_anchors_annotations[:, 4] > 0]
        per_image_anchors = per_image_anchors[
            per_image_anchors_annotations[:, 4] > 0]
        per_image_anchors_annotations = per_image_anchors_annotations[
            per_image_anchors_annotations[:, 4] > 0]
        positive_anchor_num = per_image_anchors_annotations.shape[0]

        if positive_anchor_num == 0:
            return torch.tensor(0.).to(device)

        per_image_anchors_w_h = per_image_anchors[:,
                                                  2:] - per_image_anchors[:, :2]
        per_image_anchors_ctr = per_image_anchors[:, :
                                                  2] + 0.5 * per_image_anchors_w_h

        factor = torch.tensor([[0.1, 0.1, 0.2, 0.2]]).to(device)
        per_image_reg_heads = per_image_reg_heads * factor

        pred_bboxes_wh = torch.exp(
            per_image_reg_heads[:, 2:]) * per_image_anchors_w_h
        pred_bboxes_ctr = per_image_reg_heads[:, :
                                              2] * per_image_anchors_w_h + per_image_anchors_ctr
        pred_bboxes_x_min_y_min = pred_bboxes_ctr - 0.5 * pred_bboxes_wh
        pred_bboxes_x_max_y_max = pred_bboxes_ctr + 0.5 * pred_bboxes_wh
        pred_bboxes = torch.cat(
            [pred_bboxes_x_min_y_min, pred_bboxes_x_max_y_max], axis=1)

        per_image_annotations = per_image_anchors_annotations[:, 0:4] * factor
        annotations_wh = torch.exp(
            per_image_annotations[:, 2:]) * per_image_anchors_w_h
        annotations_ctr = per_image_annotations[:, :
                                                2] * per_image_anchors_w_h + per_image_anchors_ctr
        annotations_x_min_y_min = annotations_ctr - 0.5 * annotations_wh
        annotations_x_max_y_max = annotations_ctr + 0.5 * annotations_wh
        annotations_bboxes = torch.cat(
            [annotations_x_min_y_min, annotations_x_max_y_max], axis=1)

        overlap_area_top_left = torch.max(pred_bboxes[:, 0:2],
                                          annotations_bboxes[:, 0:2])
        overlap_area_bot_right = torch.min(pred_bboxes[:, 2:4],
                                           annotations_bboxes[:, 2:4])
        overlap_area_sizes = torch.clamp(overlap_area_bot_right -
                                         overlap_area_top_left,
                                         min=0)
        overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:, 1]
        # anchors and annotations convert format to [x1,y1,w,h]
        pred_bboxes_w_h = pred_bboxes[:, 2:4] - pred_bboxes[:, 0:2] + 1
        annotations_w_h = annotations_bboxes[:,
                                             2:4] - annotations_bboxes[:,
                                                                       0:2] + 1
        # compute anchors_area and annotations_area
        pred_bboxes_area = pred_bboxes_w_h[:, 0] * pred_bboxes_w_h[:, 1]
        annotations_area = annotations_w_h[:, 0] * annotations_w_h[:, 1]

        # compute union_area
        union_area = pred_bboxes_area + annotations_area - overlap_area
        union_area = torch.clamp(union_area, min=1e-4)
        # compute ious between one image anchors and one image annotations
        ious = overlap_area / union_area

        enclose_area_top_left = torch.min(pred_bboxes[:, 0:2],
                                          annotations_bboxes[:, 0:2])
        enclose_area_bot_right = torch.max(pred_bboxes[:, 2:4],
                                           annotations_bboxes[:, 2:4])
        enclose_area_sizes = torch.clamp(enclose_area_bot_right -
                                         enclose_area_top_left,
                                         min=0)
        enclose_area = enclose_area_sizes[:, 0] * enclose_area_sizes[:, 1]
        enclose_area = torch.clamp(enclose_area, min=1e-4)

        gious_loss = 1. - ious + (enclose_area - union_area) / enclose_area
        gious_loss = gious_loss.sum() / positive_anchor_num
        gious_loss = 2.5 * gious_loss

        return gious_loss

Note that the final multiplication by 2.5 is only there to balance the classification loss and the regression loss. To use this loss, simply replace the compute_one_image_smoothl1_loss function in the RetinaLoss class with it (the for loop in the forward function additionally needs to index out each image's anchors).
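As a rough illustration of that per-image loop (a sketch only; the tensor names and batching layout are assumptions for illustration, not the repository's exact code), the regression branch of the loss could be driven like this, and the same pattern applies to the DIoU and CIoU losses below:

import torch

def batch_giou_reg_loss(loss_module, reg_heads, batch_anchors,
                        batch_anchors_annotations):
    """Hypothetical helper showing the per-image loop needed in RetinaLoss.forward.

    reg_heads: [B,anchor_num,4], batch_anchors: [B,anchor_num,4],
    batch_anchors_annotations: [B,anchor_num,5]. All names here are assumptions.
    """
    reg_losses = []
    for image_idx in range(reg_heads.shape[0]):
        # unlike smooth l1 loss, giou loss also needs this image's anchors so the
        # regression outputs can be decoded back into predicted boxes
        per_image_reg_loss = loss_module.compute_one_image_giou_loss(
            reg_heads[image_idx], batch_anchors[image_idx],
            batch_anchors_annotations[image_idx])
        reg_losses.append(per_image_reg_loss)

    return sum(reg_losses) / len(reg_losses)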

DIoU loss

Both DIoU loss and CIoU loss come from this paper: https://arxiv.org/pdf/1911.08287.pdf . DIoU loss adds one term to IoU loss: the squared Euclidean distance between the two box centers divided by the squared diagonal length of the smallest enclosing box. Like GIoU loss, DIoU loss still provides an optimization direction when the predicted box and the target box do not overlap. Because DIoU loss directly minimizes the distance between the two boxes, it converges faster than GIoU loss. In addition, DIoU can replace IoU as the suppression criterion in NMS post-processing, which makes the NMS results more reasonable. The DIoU loss is likewise defined as 1 - DIoU.
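In formula form, DIoU = IoU - ρ²(b, b_gt) / c², where ρ is the Euclidean distance between the two box centers and c is the diagonal length of the smallest enclosing box. A minimal standalone sketch (independent of the implementation below):

import torch

def box_diou(boxes_a, boxes_b):
    """DIoU for paired boxes of shape [N,4] in [x_min,y_min,x_max,y_max] format."""
    tl = torch.max(boxes_a[:, :2], boxes_b[:, :2])
    br = torch.min(boxes_a[:, 2:], boxes_b[:, 2:])
    inter = torch.clamp(br - tl, min=0).prod(dim=1)
    area_a = (boxes_a[:, 2:] - boxes_a[:, :2]).prod(dim=1)
    area_b = (boxes_b[:, 2:] - boxes_b[:, :2]).prod(dim=1)
    iou = inter / torch.clamp(area_a + area_b - inter, min=1e-4)
    # squared center distance over squared enclosing-box diagonal
    ctr_a = (boxes_a[:, :2] + boxes_a[:, 2:]) / 2
    ctr_b = (boxes_b[:, :2] + boxes_b[:, 2:]) / 2
    p2 = ((ctr_a - ctr_b)**2).sum(dim=1)
    enclose_wh = torch.max(boxes_a[:, 2:], boxes_b[:, 2:]) - torch.min(
        boxes_a[:, :2], boxes_b[:, :2])
    c2 = torch.clamp((enclose_wh**2).sum(dim=1), min=1e-4)
    return iou - p2 / c2

a = torch.tensor([[0., 0., 10., 10.]])
b = torch.tensor([[2., 2., 12., 12.]])
print(box_diou(a, b), 1. - box_diou(a, b))  # DIoU and the corresponding DIoU loss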
The DIoU loss is implemented as follows:

    def compute_one_image_diou_loss(self, per_image_reg_heads,
                                    per_image_anchors,
                                    per_image_anchors_annotations):
        """
        compute one image diou loss(reg loss)
        per_image_reg_heads:[anchor_num,4]
        per_image_anchors:[anchor_num,4]
        per_image_anchors_annotations:[anchor_num,5]
        """
        # Only positive anchors (gt class > 0) contribute to the diou loss
        device = per_image_reg_heads.device
        per_image_reg_heads = per_image_reg_heads[
            per_image_anchors_annotations[:, 4] > 0]
        per_image_anchors = per_image_anchors[
            per_image_anchors_annotations[:, 4] > 0]
        per_image_anchors_annotations = per_image_anchors_annotations[
            per_image_anchors_annotations[:, 4] > 0]
        positive_anchor_num = per_image_anchors_annotations.shape[0]

        if positive_anchor_num == 0:
            return torch.tensor(0.).to(device)

        per_image_anchors_w_h = per_image_anchors[:,
                                                  2:] - per_image_anchors[:, :2]
        per_image_anchors_ctr = per_image_anchors[:, :
                                                  2] + 0.5 * per_image_anchors_w_h

        factor = torch.tensor([[0.1, 0.1, 0.2, 0.2]]).to(device)
        per_image_reg_heads = per_image_reg_heads * factor

        pred_bboxes_wh = torch.exp(
            per_image_reg_heads[:, 2:]) * per_image_anchors_w_h
        pred_bboxes_ctr = per_image_reg_heads[:, :
                                              2] * per_image_anchors_w_h + per_image_anchors_ctr
        pred_bboxes_x_min_y_min = pred_bboxes_ctr - 0.5 * pred_bboxes_wh
        pred_bboxes_x_max_y_max = pred_bboxes_ctr + 0.5 * pred_bboxes_wh
        pred_bboxes = torch.cat(
            [pred_bboxes_x_min_y_min, pred_bboxes_x_max_y_max], axis=1)

        per_image_annotations = per_image_anchors_annotations[:, 0:4] * factor
        annotations_wh = torch.exp(
            per_image_annotations[:, 2:]) * per_image_anchors_w_h
        annotations_ctr = per_image_annotations[:, :
                                                2] * per_image_anchors_w_h + per_image_anchors_ctr
        annotations_x_min_y_min = annotations_ctr - 0.5 * annotations_wh
        annotations_x_max_y_max = annotations_ctr + 0.5 * annotations_wh
        annotations_bboxes = torch.cat(
            [annotations_x_min_y_min, annotations_x_max_y_max], axis=1)

        overlap_area_top_left = torch.max(pred_bboxes[:, 0:2],
                                          annotations_bboxes[:, 0:2])
        overlap_area_bot_right = torch.min(pred_bboxes[:, 2:4],
                                           annotations_bboxes[:, 2:4])
        overlap_area_sizes = torch.clamp(overlap_area_bot_right -
                                         overlap_area_top_left,
                                         min=0)
        overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:, 1]
        # anchors and annotations convert format to [x1,y1,w,h]
        pred_bboxes_w_h = pred_bboxes[:, 2:4] - pred_bboxes[:, 0:2] + 1
        annotations_w_h = annotations_bboxes[:,
                                             2:4] - annotations_bboxes[:,
                                                                       0:2] + 1
        # compute anchors_area and annotations_area
        pred_bboxes_area = pred_bboxes_w_h[:, 0] * pred_bboxes_w_h[:, 1]
        annotations_area = annotations_w_h[:, 0] * annotations_w_h[:, 1]

        # compute union_area
        union_area = pred_bboxes_area + annotations_area - overlap_area
        union_area = torch.clamp(union_area, min=1e-4)
        # compute ious between one image anchors and one image annotations
        ious = overlap_area / union_area

        pred_bboxes_ctr = (pred_bboxes[:, 2:4] + pred_bboxes[:, 0:2]) / 2
        annotations_bboxes_ctr = (annotations_bboxes[:, 2:4] +
                                  annotations_bboxes[:, 0:2]) / 2
        p2 = (pred_bboxes_ctr[:, 0] - annotations_bboxes_ctr[:, 0])**2 + (
            pred_bboxes_ctr[:, 1] - annotations_bboxes_ctr[:, 1])**2

        enclose_area_top_left = torch.min(pred_bboxes[:, 0:2],
                                          annotations_bboxes[:, 0:2])
        enclose_area_bot_right = torch.max(pred_bboxes[:, 2:4],
                                           annotations_bboxes[:, 2:4])
        enclose_area_sizes = torch.clamp(enclose_area_bot_right -
                                         enclose_area_top_left,
                                         min=1e-2)
        c2 = (enclose_area_sizes[:, 0])**2 + (enclose_area_sizes[:, 1])**2

        dious_loss = 1. - ious + p2 / c2
        dious_loss = dious_loss.sum() / positive_anchor_num
        dious_loss = 2. * dious_loss

        return dious_loss

Note that the final multiplication by 2 is only there to balance the classification loss and the regression loss. To use this loss, simply replace the compute_one_image_smoothl1_loss function in the RetinaLoss class with it (the for loop in the forward function additionally needs to index out each image's anchors).
The paper also points out that DIoU can replace IoU inside NMS, which gives better post-processing results.
The DIoU-NMS implementation is as follows:

    def nms(self, one_image_scores, one_image_classes, one_image_pred_bboxes):
        """
        one_image_scores:[anchor_nums],classification predict scores
        one_image_classes:[anchor_nums],class indexes for predict scores
        one_image_pred_bboxes:[anchor_nums,4],4:x_min,y_min,x_max,y_max
        """
        # Sort boxes
        sorted_one_image_scores, sorted_one_image_scores_indexes = torch.sort(
            one_image_scores, descending=True)
        sorted_one_image_classes = one_image_classes[
            sorted_one_image_scores_indexes]
        sorted_one_image_pred_bboxes = one_image_pred_bboxes[
            sorted_one_image_scores_indexes]
        sorted_pred_bboxes_w_h = sorted_one_image_pred_bboxes[:,
                                                              2:] - sorted_one_image_pred_bboxes[:, :
                                                                                                 2]

        sorted_pred_bboxes_areas = sorted_pred_bboxes_w_h[:,
                                                          0] * sorted_pred_bboxes_w_h[:,
                                                                                      1]
        detected_classes = torch.unique(sorted_one_image_classes, sorted=True)

        keep_scores, keep_classes, keep_pred_bboxes = [], [], []
        for detected_class in detected_classes:
            single_class_scores = sorted_one_image_scores[
                sorted_one_image_classes == detected_class]
            single_class_pred_bboxes = sorted_one_image_pred_bboxes[
                sorted_one_image_classes == detected_class]
            single_class_pred_bboxes_areas = sorted_pred_bboxes_areas[
                sorted_one_image_classes == detected_class]
            single_class = sorted_one_image_classes[sorted_one_image_classes ==
                                                    detected_class]

            single_keep_scores, single_keep_classes, single_keep_pred_bboxes = [], [], []
            while single_class_scores.numel() > 0:
                top1_score, top1_class, top1_pred_bbox = single_class_scores[
                    0:1], single_class[0:1], single_class_pred_bboxes[0:1]

                single_keep_scores.append(top1_score)
                single_keep_classes.append(top1_class)
                single_keep_pred_bboxes.append(top1_pred_bbox)

                top1_areas = single_class_pred_bboxes_areas[0]

                if single_class_scores.numel() == 1:
                    break

                single_class_scores = single_class_scores[1:]
                single_class = single_class[1:]
                single_class_pred_bboxes = single_class_pred_bboxes[1:]
                single_class_pred_bboxes_areas = single_class_pred_bboxes_areas[
                    1:]

                overlap_area_top_left = torch.max(
                    single_class_pred_bboxes[:, :2], top1_pred_bbox[:, :2])
                overlap_area_bot_right = torch.min(
                    single_class_pred_bboxes[:, 2:], top1_pred_bbox[:, 2:])
                overlap_area_sizes = torch.clamp(overlap_area_bot_right -
                                                 overlap_area_top_left,
                                                 min=0)
                overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:,
                                                                             1]

                # compute union_area
                union_area = top1_areas + single_class_pred_bboxes_areas - overlap_area
                union_area = torch.clamp(union_area, min=1e-4)
                # compute ious for top1 pred_bbox and the other pred_bboxes
                ious = overlap_area / union_area

                top1_pred_bbox_ctr = (top1_pred_bbox[:, 2:4] +
                                      top1_pred_bbox[:, 0:2]) / 2
                single_class_pred_bboxes_ctr = (
                    single_class_pred_bboxes[:, 2:4] +
                    single_class_pred_bboxes[:, 0:2]) / 2
                p2 = (top1_pred_bbox_ctr[:, 0] -
                      single_class_pred_bboxes_ctr[:, 0])**2 + (
                          top1_pred_bbox_ctr[:, 1] -
                          single_class_pred_bboxes_ctr[:, 1])**2

                enclose_area_top_left = torch.min(
                    top1_pred_bbox[:, 0:2], single_class_pred_bboxes[:, 0:2])
                enclose_area_bot_right = torch.max(
                    top1_pred_bbox[:, 2:4], single_class_pred_bboxes[:, 2:4])
                enclose_area_sizes = torch.clamp(enclose_area_bot_right -
                                                 enclose_area_top_left,
                                                 min=1e-4)
                c2 = (enclose_area_sizes[:, 0])**2 + (enclose_area_sizes[:,
                                                                         1])**2

                dious = ious - p2 / c2

                single_class_scores = single_class_scores[
                    dious < self.nms_threshold]
                single_class = single_class[dious < self.nms_threshold]
                single_class_pred_bboxes = single_class_pred_bboxes[
                    dious < self.nms_threshold]
                single_class_pred_bboxes_areas = single_class_pred_bboxes_areas[
                    dious < self.nms_threshold]

            single_keep_scores = torch.cat(single_keep_scores, axis=0)
            single_keep_classes = torch.cat(single_keep_classes, axis=0)
            single_keep_pred_bboxes = torch.cat(single_keep_pred_bboxes,
                                                axis=0)

            keep_scores.append(single_keep_scores)
            keep_classes.append(single_keep_classes)
            keep_pred_bboxes.append(single_keep_pred_bboxes)

        keep_scores = torch.cat(keep_scores, axis=0)
        keep_classes = torch.cat(keep_classes, axis=0)
        keep_pred_bboxes = torch.cat(keep_pred_bboxes, axis=0)

        return keep_scores, keep_classes, keep_pred_bboxes

This NMS function can directly replace the NMS function in the RetinaDecoder class. According to the experimental results reported in the paper, the replacement improves results even with nms_threshold left unchanged.
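To see that the replacement really is drop-in (a hypothetical usage check; `decoder` below is assumed to be a RetinaDecoder instance whose nms method has been swapped for the DIoU version, and 0.5 is only an example nms_threshold):

import torch

# three detections sorted by score: two overlapping class-0 boxes and one class-1 box
one_image_scores = torch.tensor([0.9, 0.85, 0.3])
one_image_classes = torch.tensor([0., 0., 1.])
one_image_pred_bboxes = torch.tensor([[0., 0., 10., 10.],
                                      [1., 1., 11., 11.],
                                      [50., 50., 60., 60.]])

# same signature as the original IoU-based nms, so nothing else in the decoder changes
keep_scores, keep_classes, keep_pred_bboxes = decoder.nms(
    one_image_scores, one_image_classes, one_image_pred_bboxes)
# with nms_threshold=0.5 the second box is suppressed (its DIoU with the first box is
# about 0.67), while the class-1 box is kept because suppression runs per class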

CIoU loss

Both DIoU loss and CIoU loss come from the same paper: https://arxiv.org/pdf/1911.08287.pdf . CIoU loss adds one more term, αv, on top of DIoU loss (α is a weighting coefficient and v measures the consistency of the aspect ratios), which penalizes the difference in aspect ratio between the predicted box and the target box and makes the regression direction more reasonable. Likewise, CIoU can replace IoU as the suppression criterion in NMS post-processing, making the NMS results more reasonable.
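Concretely, v = (4/π²) · (arctan(w_gt/h_gt) - arctan(w/h))², α = v / (1 - IoU + v), and CIoU loss = 1 - IoU + ρ²/c² + αv, where α is treated as a constant weight so no gradient flows through it. A minimal standalone sketch of the extra term (independent of the implementation below):

import math
import torch

def aspect_ratio_penalty(pred_wh, target_wh, ious):
    """Compute v and alpha from the CIoU paper for paired boxes.

    pred_wh, target_wh: [N,2] widths/heights; ious: [N] IoU of each pair.
    """
    v = (4 / (math.pi**2)) * torch.pow(
        torch.atan(target_wh[:, 0] / target_wh[:, 1]) -
        torch.atan(pred_wh[:, 0] / pred_wh[:, 1]), 2)
    # alpha is only a weighting factor: compute it under no_grad so it is detached
    with torch.no_grad():
        alpha = v / torch.clamp(1 - ious + v, min=1e-4)
    return v, alpha

pred_wh = torch.tensor([[10., 5.]], requires_grad=True)
target_wh = torch.tensor([[10., 10.]])
ious = torch.tensor([0.5])
v, alpha = aspect_ratio_penalty(pred_wh, target_wh, ious)
print(v, alpha)  # v > 0 because the aspect ratios (2:1 vs 1:1) disagree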
The CIoU loss is implemented as follows:

    def compute_one_image_ciou_loss(self, per_image_reg_heads,
                                    per_image_anchors,
                                    per_image_anchors_annotations):
        """
        compute one image ciou loss(reg loss)
        per_image_reg_heads:[anchor_num,4]
        per_image_anchors:[anchor_num,4]
        per_image_anchors_annotations:[anchor_num,5]
        """
        # Only positive anchors (gt class > 0) contribute to the ciou loss
        device = per_image_reg_heads.device
        per_image_reg_heads = per_image_reg_heads[
            per_image_anchors_annotations[:, 4] > 0]
        per_image_anchors = per_image_anchors[
            per_image_anchors_annotations[:, 4] > 0]
        per_image_anchors_annotations = per_image_anchors_annotations[
            per_image_anchors_annotations[:, 4] > 0]
        positive_anchor_num = per_image_anchors_annotations.shape[0]

        if positive_anchor_num == 0:
            return torch.tensor(0.).to(device)

        per_image_anchors_w_h = per_image_anchors[:,
                                                  2:] - per_image_anchors[:, :2]
        per_image_anchors_ctr = per_image_anchors[:, :
                                                  2] + 0.5 * per_image_anchors_w_h

        factor = torch.tensor([[0.1, 0.1, 0.2, 0.2]]).to(device)
        per_image_reg_heads = per_image_reg_heads * factor

        pred_bboxes_wh = torch.exp(
            per_image_reg_heads[:, 2:]) * per_image_anchors_w_h
        pred_bboxes_ctr = per_image_reg_heads[:, :
                                              2] * per_image_anchors_w_h + per_image_anchors_ctr
        pred_bboxes_x_min_y_min = pred_bboxes_ctr - 0.5 * pred_bboxes_wh
        pred_bboxes_x_max_y_max = pred_bboxes_ctr + 0.5 * pred_bboxes_wh
        pred_bboxes = torch.cat(
            [pred_bboxes_x_min_y_min, pred_bboxes_x_max_y_max], axis=1)

        per_image_annotations = per_image_anchors_annotations[:, 0:4] * factor
        annotations_wh = torch.exp(
            per_image_annotations[:, 2:]) * per_image_anchors_w_h
        annotations_ctr = per_image_annotations[:, :
                                                2] * per_image_anchors_w_h + per_image_anchors_ctr
        annotations_x_min_y_min = annotations_ctr - 0.5 * annotations_wh
        annotations_x_max_y_max = annotations_ctr + 0.5 * annotations_wh
        annotations_bboxes = torch.cat(
            [annotations_x_min_y_min, annotations_x_max_y_max], axis=1)

        overlap_area_top_left = torch.max(pred_bboxes[:, 0:2],
                                          annotations_bboxes[:, 0:2])
        overlap_area_bot_right = torch.min(pred_bboxes[:, 2:4],
                                           annotations_bboxes[:, 2:4])
        overlap_area_sizes = torch.clamp(overlap_area_bot_right -
                                         overlap_area_top_left,
                                         min=0)
        overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:, 1]
        # anchors and annotations convert format to [x1,y1,w,h]
        pred_bboxes_w_h = pred_bboxes[:, 2:4] - pred_bboxes[:, 0:2] + 1
        annotations_w_h = annotations_bboxes[:,
                                             2:4] - annotations_bboxes[:,
                                                                       0:2] + 1
        # compute anchors_area and annotations_area
        pred_bboxes_area = pred_bboxes_w_h[:, 0] * pred_bboxes_w_h[:, 1]
        annotations_area = annotations_w_h[:, 0] * annotations_w_h[:, 1]

        # compute union_area
        union_area = pred_bboxes_area + annotations_area - overlap_area
        union_area = torch.clamp(union_area, min=1e-4)
        # compute ious between one image anchors and one image annotations
        ious = overlap_area / union_area

        pred_bboxes_ctr = (pred_bboxes[:, 2:4] + pred_bboxes[:, 0:2]) / 2
        annotations_bboxes_ctr = (annotations_bboxes[:, 2:4] +
                                  annotations_bboxes[:, 0:2]) / 2
        p2 = (pred_bboxes_ctr[:, 0] - annotations_bboxes_ctr[:, 0])**2 + (
            pred_bboxes_ctr[:, 1] - annotations_bboxes_ctr[:, 1])**2

        enclose_area_top_left = torch.min(pred_bboxes[:, 0:2],
                                          annotations_bboxes[:, 0:2])
        enclose_area_bot_right = torch.max(pred_bboxes[:, 2:4],
                                           annotations_bboxes[:, 2:4])
        enclose_area_sizes = torch.clamp(enclose_area_bot_right -
                                         enclose_area_top_left,
                                         min=1e-2)
        c2 = (enclose_area_sizes[:, 0])**2 + (enclose_area_sizes[:, 1])**2

        # v measures aspect ratio consistency; alpha is only a weighting factor,
        # so it is computed under torch.no_grad() and carries no gradient
        v = torch.pow(
            (torch.atan(annotations_w_h[:, 0] / annotations_w_h[:, 1]) -
             torch.atan(pred_bboxes_w_h[:, 0] / pred_bboxes_w_h[:, 1])),
            2) * (4 / (math.pi**2))
        with torch.no_grad():
            alpha = v / (1 - ious + v)

        cious_loss = 1. - ious + p2 / c2 + alpha * v
        cious_loss = cious_loss.sum() / positive_anchor_num
        cious_loss = 2. * cious_loss

        return cious_loss

Note that alpha is a weighting coefficient and does not backpropagate gradients, and that the final multiplication by 2 is only there to balance the classification loss and the regression loss. To use this loss, simply replace the compute_one_image_smoothl1_loss function in the RetinaLoss class with it (the for loop in the forward function additionally needs to index out each image's anchors).
The paper proposes replacing IoU with DIoU in NMS for better post-processing; I also tried replacing IoU in NMS with CIoU.
The CIoU-NMS implementation is as follows:

    def nms(self, one_image_scores, one_image_classes, one_image_pred_bboxes):
        """
        one_image_scores:[anchor_nums],classification predict scores
        one_image_classes:[anchor_nums],class indexes for predict scores
        one_image_pred_bboxes:[anchor_nums,4],4:x_min,y_min,x_max,y_max
        """
        # Sort boxes
        sorted_one_image_scores, sorted_one_image_scores_indexes = torch.sort(
            one_image_scores, descending=True)
        sorted_one_image_classes = one_image_classes[
            sorted_one_image_scores_indexes]
        sorted_one_image_pred_bboxes = one_image_pred_bboxes[
            sorted_one_image_scores_indexes]
        sorted_pred_bboxes_w_h = sorted_one_image_pred_bboxes[:,
                                                              2:] - sorted_one_image_pred_bboxes[:, :
                                                                                                 2]

        sorted_pred_bboxes_areas = sorted_pred_bboxes_w_h[:,
                                                          0] * sorted_pred_bboxes_w_h[:,
                                                                                      1]
        detected_classes = torch.unique(sorted_one_image_classes, sorted=True)

        keep_scores, keep_classes, keep_pred_bboxes = [], [], []
        for detected_class in detected_classes:
            single_class_scores = sorted_one_image_scores[
                sorted_one_image_classes == detected_class]
            single_class_pred_bboxes = sorted_one_image_pred_bboxes[
                sorted_one_image_classes == detected_class]
            single_class_pred_bboxes_areas = sorted_pred_bboxes_areas[
                sorted_one_image_classes == detected_class]
            single_class = sorted_one_image_classes[sorted_one_image_classes ==
                                                    detected_class]

            single_keep_scores, single_keep_classes, single_keep_pred_bboxes = [], [], []
            while single_class_scores.numel() > 0:
                top1_score, top1_class, top1_pred_bbox = single_class_scores[
                    0:1], single_class[0:1], single_class_pred_bboxes[0:1]

                single_keep_scores.append(top1_score)
                single_keep_classes.append(top1_class)
                single_keep_pred_bboxes.append(top1_pred_bbox)

                top1_areas = single_class_pred_bboxes_areas[0]

                if single_class_scores.numel() == 1:
                    break

                single_class_scores = single_class_scores[1:]
                single_class = single_class[1:]
                single_class_pred_bboxes = single_class_pred_bboxes[1:]
                single_class_pred_bboxes_areas = single_class_pred_bboxes_areas[
                    1:]

                overlap_area_top_left = torch.max(
                    single_class_pred_bboxes[:, :2], top1_pred_bbox[:, :2])
                overlap_area_bot_right = torch.min(
                    single_class_pred_bboxes[:, 2:], top1_pred_bbox[:, 2:])
                overlap_area_sizes = torch.clamp(overlap_area_bot_right -
                                                 overlap_area_top_left,
                                                 min=0)
                overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:,
                                                                             1]

                # compute union_area
                union_area = top1_areas + single_class_pred_bboxes_areas - overlap_area
                union_area = torch.clamp(union_area, min=1e-4)
                # compute ious for top1 pred_bbox and the other pred_bboxes
                ious = overlap_area / union_area

                top1_pred_bbox_ctr = (top1_pred_bbox[:, 2:4] +
                                      top1_pred_bbox[:, 0:2]) / 2
                single_class_pred_bboxes_ctr = (
                    single_class_pred_bboxes[:, 2:4] +
                    single_class_pred_bboxes[:, 0:2]) / 2
                p2 = (top1_pred_bbox_ctr[:, 0] -
                      single_class_pred_bboxes_ctr[:, 0])**2 + (
                          top1_pred_bbox_ctr[:, 1] -
                          single_class_pred_bboxes_ctr[:, 1])**2

                enclose_area_top_left = torch.min(
                    top1_pred_bbox[:, 0:2], single_class_pred_bboxes[:, 0:2])
                enclose_area_bot_right = torch.max(
                    top1_pred_bbox[:, 2:4], single_class_pred_bboxes[:, 2:4])
                enclose_area_sizes = torch.clamp(enclose_area_bot_right -
                                                 enclose_area_top_left,
                                                 min=1e-4)
                c2 = (enclose_area_sizes[:, 0])**2 + (enclose_area_sizes[:,
                                                                         1])**2

                top1_pred_bbox_wh = top1_pred_bbox[:,
                                                   2:4] - top1_pred_bbox[:,
                                                                         0:2]
                single_class_pred_bboxes_wh = single_class_pred_bboxes[:, 2:
                                                                       4] - single_class_pred_bboxes[:,
                                                                                                     0:
                                                                                                     2]
                v = torch.pow(
                    (torch.atan(single_class_pred_bboxes_wh[:, 0] /
                                single_class_pred_bboxes_wh[:, 1]) -
                     torch.atan(top1_pred_bbox_wh[:, 0] /
                                top1_pred_bbox_wh[:, 1])), 2) * (4 /
                                                                 (math.pi**2))
                alpha = v / (1 - ious + v)

                cious = ious - p2 / c2 - alpha * v

                single_class_scores = single_class_scores[
                    cious < self.nms_threshold]
                single_class = single_class[cious < self.nms_threshold]
                single_class_pred_bboxes = single_class_pred_bboxes[
                    cious < self.nms_threshold]
                single_class_pred_bboxes_areas = single_class_pred_bboxes_areas[
                    cious < self.nms_threshold]

            single_keep_scores = torch.cat(single_keep_scores, axis=0)
            single_keep_classes = torch.cat(single_keep_classes, axis=0)
            single_keep_pred_bboxes = torch.cat(single_keep_pred_bboxes,
                                                axis=0)

            keep_scores.append(single_keep_scores)
            keep_classes.append(single_keep_classes)
            keep_pred_bboxes.append(single_keep_pred_bboxes)

        keep_scores = torch.cat(keep_scores, axis=0)
        keep_classes = torch.cat(keep_classes, axis=0)
        keep_pred_bboxes = torch.cat(keep_pred_bboxes, axis=0)

        return keep_scores, keep_classes, keep_pred_bboxes

This NMS function can also directly replace the NMS function in the RetinaDecoder class; as with DIoU-NMS, nms_threshold is left unchanged.

Experimental results

The experiments take the final ResNet50-RetinaNet-aug-iscrowd model from "Implementing RetinaNet from Scratch (7)" as the baseline. The input resolution is 600, which is roughly equivalent to resolution 450 in the RetinaNet paper.

| Network | batch | gpu-num | apex | syncbn | epoch5-mAP-mAR-loss | epoch10-mAP-mAR-loss | epoch12-mAP-mAR-loss |
| --- | --- | --- | --- | --- | --- | --- | --- |
| baseline | 15 | 1 | yes | no | 0.254,0.394,0.62 | 0.280,0.418,0.53 | 0.286,0.421,0.50 |
| baseline-giou | 15 | 1 | yes | no | 0.259,0.400,0.68 | 0.284,0.425,0.59 | 0.292,0.429,0.57 |
| baseline-diou | 15 | 1 | yes | no | 0.257,0.399,0.67 | 0.285,0.422,0.59 | 0.291,0.424,0.57 |
| baseline-diou-nms | 15 | 1 | yes | no | 0.256,0.395,0.67 | 0.286,0.424,0.59 | 0.290,0.428,0.57 |
| baseline-ciou | 15 | 1 | yes | no | 0.263,0.400,0.67 | 0.287,0.420,0.59 | 0.292,0.426,0.57 |
| baseline-ciou-nms | 15 | 1 | yes | no | 0.257,0.401,0.67 | 0.284,0.418,0.59 | 0.292,0.429,0.57 |

baseline-giou means the regression loss is GIoU loss. baseline-diou-nms means the regression loss is DIoU loss and IoU in NMS is replaced with DIoU. baseline-ciou-nms means the regression loss is CIoU loss and IoU in NMS is replaced with CIoU. From these results, models trained with GIoU, DIoU, or CIoU loss perform better than the smooth L1 baseline, but replacing the IoU criterion in NMS with DIoU or CIoU brings no obvious additional gain.
