[Dissecting the Ox] Implementing RetinaNet from Scratch (8): Improving the RetinaNet Regression Loss with GIoU, DIoU and CIoU

All the code has been uploaded to my GitHub repository: https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training
If you find it useful, please give it a star!
All the code below has been tested under PyTorch 1.4 and confirmed to work correctly.

RetinaNet uses smooth L1 loss as its regression loss, but IoU itself can also serve as a loss. Plain IoU has a drawback, though: two boxes that do not overlap always have an IoU of 0, no matter how far apart they are, so the loss cannot reflect the distance between them. GIoU loss, DIoU loss and CIoU loss were proposed as successive improvements built on IoU loss to address this.
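
To make that drawback concrete, here is a minimal standalone sketch (the iou_xyxy helper and the example boxes are made up for illustration and are not part of the repository code). Two pairs of disjoint boxes at very different distances both get an IoU of 0, so an IoU loss would give them the same value and no useful gradient:

import torch

def iou_xyxy(boxes_a, boxes_b):
    # boxes in [x_min, y_min, x_max, y_max] format, shape [N, 4]
    overlap_top_left = torch.max(boxes_a[:, :2], boxes_b[:, :2])
    overlap_bot_right = torch.min(boxes_a[:, 2:], boxes_b[:, 2:])
    overlap_wh = torch.clamp(overlap_bot_right - overlap_top_left, min=0)
    overlap_area = overlap_wh[:, 0] * overlap_wh[:, 1]
    areas_a = (boxes_a[:, 2] - boxes_a[:, 0]) * (boxes_a[:, 3] - boxes_a[:, 1])
    areas_b = (boxes_b[:, 2] - boxes_b[:, 0]) * (boxes_b[:, 3] - boxes_b[:, 1])
    union_area = torch.clamp(areas_a + areas_b - overlap_area, min=1e-4)
    return overlap_area / union_area

boxes_a = torch.tensor([[0., 0., 10., 10.], [0., 0., 10., 10.]])
boxes_b = torch.tensor([[20., 20., 30., 30.], [200., 200., 210., 210.]])
print(iou_xyxy(boxes_a, boxes_b))  # tensor([0., 0.]): same IoU despite very different distances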

GIoU loss

GIoU loss comes from this paper: https://arxiv.org/pdf/1902.09630.pdf . On top of IoU, GIoU adds the proportion of the smallest enclosing box that is covered by neither of the two boxes. This term measures the distance between two non-overlapping boxes, so it can still provide an optimization direction for the bounding box even when the two boxes do not intersect. IoU lies in [0, 1] and GIoU lies in [-1, 1]; the GIoU loss is 1 - GIoU, so its range is [0, 2]. Note that I replace smooth L1 loss with GIoU loss directly, and the samples being learned are all positive samples (whose IoU is never 0), so the fully non-intersecting case does not actually occur during optimization.
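Written as a formula, with A the predicted box, B the ground-truth box, C their smallest enclosing box and |·| denoting area, this matches the 1. - ious + (enclose_area - union_area) / enclose_area line in the code below:

GIoU = IoU - \frac{|C \setminus (A \cup B)|}{|C|}
L_{GIoU} = 1 - GIoU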
The GIoU loss implementation is as follows:

    def compute_one_image_giou_loss(self, per_image_reg_heads,
                                    per_image_anchors,
                                    per_image_anchors_annotations):
        """
        compute one image giou loss(reg loss)
        per_image_reg_heads:[anchor_num,4]
        per_image_anchors:[anchor_num,4]
        per_image_anchors_annotations:[anchor_num,5]
        """
        # keep only positive anchors (gt class index > 0), other anchors don't calculate giou loss
        device = per_image_reg_heads.device
        per_image_reg_heads = per_image_reg_heads[
            per_image_anchors_annotations[:, 4] > 0]
        per_image_anchors = per_image_anchors[
            per_image_anchors_annotations[:, 4] > 0]
        per_image_anchors_annotations = per_image_anchors_annotations[
            per_image_anchors_annotations[:, 4] > 0]
        positive_anchor_num = per_image_anchors_annotations.shape[0]

        if positive_anchor_num == 0:
            return torch.tensor(0.).to(device)

        per_image_anchors_w_h = per_image_anchors[:,
                                                  2:] - per_image_anchors[:, :2]
        per_image_anchors_ctr = per_image_anchors[:, :
                                                  2] + 0.5 * per_image_anchors_w_h

        factor = torch.tensor([[0.1, 0.1, 0.2, 0.2]]).to(device)
        per_image_reg_heads = per_image_reg_heads * factor

        pred_bboxes_wh = torch.exp(
            per_image_reg_heads[:, 2:]) * per_image_anchors_w_h
        pred_bboxes_ctr = per_image_reg_heads[:, :
                                              2] * per_image_anchors_w_h + per_image_anchors_ctr
        pred_bboxes_x_min_y_min = pred_bboxes_ctr - 0.5 * pred_bboxes_wh
        pred_bboxes_x_max_y_max = pred_bboxes_ctr + 0.5 * pred_bboxes_wh
        pred_bboxes = torch.cat(
            [pred_bboxes_x_min_y_min, pred_bboxes_x_max_y_max], axis=1)

        per_image_annotations = per_image_anchors_annotations[:, 0:4] * factor
        annotations_wh = torch.exp(
            per_image_annotations[:, 2:]) * per_image_anchors_w_h
        annotations_ctr = per_image_annotations[:, :
                                                2] * per_image_anchors_w_h + per_image_anchors_ctr
        annotations_x_min_y_min = annotations_ctr - 0.5 * annotations_wh
        annotations_x_max_y_max = annotations_ctr + 0.5 * annotations_wh
        annotations_bboxes = torch.cat(
            [annotations_x_min_y_min, annotations_x_max_y_max], axis=1)

        overlap_area_top_left = torch.max(pred_bboxes[:, 0:2],
                                          annotations_bboxes[:, 0:2])
        overlap_area_bot_right = torch.min(pred_bboxes[:, 2:4],
                                           annotations_bboxes[:, 2:4])
        overlap_area_sizes = torch.clamp(overlap_area_bot_right -
                                         overlap_area_top_left,
                                         min=0)
        overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:, 1]
        # anchors and annotations convert format to [x1,y1,w,h]
        pred_bboxes_w_h = pred_bboxes[:, 2:4] - pred_bboxes[:, 0:2] + 1
        annotations_w_h = annotations_bboxes[:,
                                             2:4] - annotations_bboxes[:,
                                                                       0:2] + 1
        # compute anchors_area and annotations_area
        pred_bboxes_area = pred_bboxes_w_h[:, 0] * pred_bboxes_w_h[:, 1]
        annotations_area = annotations_w_h[:, 0] * annotations_w_h[:, 1]

        # compute union_area
        union_area = pred_bboxes_area + annotations_area - overlap_area
        union_area = torch.clamp(union_area, min=1e-4)
        # compute ious between one image anchors and one image annotations
        ious = overlap_area / union_area

        enclose_area_top_left = torch.min(pred_bboxes[:, 0:2],
                                          annotations_bboxes[:, 0:2])
        enclose_area_bot_right = torch.max(pred_bboxes[:, 2:4],
                                           annotations_bboxes[:, 2:4])
        enclose_area_sizes = torch.clamp(enclose_area_bot_right -
                                         enclose_area_top_left,
                                         min=0)
        enclose_area = enclose_area_sizes[:, 0] * enclose_area_sizes[:, 1]
        enclose_area = torch.clamp(enclose_area, min=1e-4)

        gious_loss = 1. - ious + (enclose_area - union_area) / enclose_area
        gious_loss = gious_loss.sum() / positive_anchor_num
        gious_loss = 2.5 * gious_loss

        return gious_loss

Note that the final multiplication by 2.5 is only there to balance the classification loss and the regression loss. To use this loss, simply replace the compute_one_image_smoothl1_loss function in the RetinaLoss class with it (in the for loop of the forward function you additionally need to take each image's anchors), as sketched below.
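
For reference, a minimal sketch of that per-image loop as a standalone helper. The function name compute_batch_giou_loss and the batch-first tensor layout are assumptions made for illustration, not code from the repository:

def compute_batch_giou_loss(loss_module, reg_heads, batch_anchors,
                            batch_anchors_annotations):
    # loss_module is assumed to be a RetinaLoss instance that already contains
    # compute_one_image_giou_loss; inputs are assumed to be batch-first tensors:
    # reg_heads:[B,anchor_num,4], batch_anchors:[B,anchor_num,4],
    # batch_anchors_annotations:[B,anchor_num,5]
    reg_losses = []
    for per_image_reg_heads, per_image_anchors, per_image_anchors_annotations in zip(
            reg_heads, batch_anchors, batch_anchors_annotations):
        reg_losses.append(
            loss_module.compute_one_image_giou_loss(
                per_image_reg_heads, per_image_anchors,
                per_image_anchors_annotations))
    # average the per-image giou losses over the batch
    return sum(reg_losses) / len(reg_losses)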

DIoU loss

Both DIoU loss and CIoU loss come from this paper: https://arxiv.org/pdf/1911.08287.pdf . DIoU loss adds one term on top of IoU loss: the squared Euclidean distance between the two box centers divided by the squared diagonal length of the smallest enclosing box. Like GIoU loss, DIoU loss can still provide an optimization direction when the predicted box and the target box do not overlap. Because DIoU loss directly minimizes the distance between the two boxes, it converges faster than GIoU loss. In addition, DIoU can replace IoU as the suppression criterion in NMS post-processing, which makes the NMS results more reasonable. The DIoU loss is likewise 1 - DIoU.
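Written as a formula, where b and b^{gt} are the centers of the predicted box and the target box, ρ(·) is the Euclidean distance and c is the diagonal length of the smallest enclosing box (this is the p2 / c2 term in the code below):

L_{DIoU} = 1 - IoU + \frac{\rho^2(b, b^{gt})}{c^2}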
The DIoU loss implementation is as follows:

    def compute_one_image_diou_loss(self, per_image_reg_heads,
                                    per_image_anchors,
                                    per_image_anchors_annotations):
        """
        compute one image diou loss(reg loss)
        per_image_reg_heads:[anchor_num,4]
        per_image_anchors:[anchor_num,4]
        per_image_anchors_annotations:[anchor_num,5]
        """
        # keep only positive anchors (gt class index > 0), other anchors don't calculate diou loss
        device = per_image_reg_heads.device
        per_image_reg_heads = per_image_reg_heads[
            per_image_anchors_annotations[:, 4] > 0]
        per_image_anchors = per_image_anchors[
            per_image_anchors_annotations[:, 4] > 0]
        per_image_anchors_annotations = per_image_anchors_annotations[
            per_image_anchors_annotations[:, 4] > 0]
        positive_anchor_num = per_image_anchors_annotations.shape[0]

        if positive_anchor_num == 0:
            return torch.tensor(0.).to(device)

        per_image_anchors_w_h = per_image_anchors[:,
                                                  2:] - per_image_anchors[:, :2]
        per_image_anchors_ctr = per_image_anchors[:, :
                                                  2] + 0.5 * per_image_anchors_w_h

        factor = torch.tensor([[0.1, 0.1, 0.2, 0.2]]).to(device)
        per_image_reg_heads = per_image_reg_heads * factor

        pred_bboxes_wh = torch.exp(
            per_image_reg_heads[:, 2:]) * per_image_anchors_w_h
        pred_bboxes_ctr = per_image_reg_heads[:, :
                                              2] * per_image_anchors_w_h + per_image_anchors_ctr
        pred_bboxes_x_min_y_min = pred_bboxes_ctr - 0.5 * pred_bboxes_wh
        pred_bboxes_x_max_y_max = pred_bboxes_ctr + 0.5 * pred_bboxes_wh
        pred_bboxes = torch.cat(
            [pred_bboxes_x_min_y_min, pred_bboxes_x_max_y_max], axis=1)

        per_image_annotations = per_image_anchors_annotations[:, 0:4] * factor
        annotations_wh = torch.exp(
            per_image_annotations[:, 2:]) * per_image_anchors_w_h
        annotations_ctr = per_image_annotations[:, :
                                                2] * per_image_anchors_w_h + per_image_anchors_ctr
        annotations_x_min_y_min = annotations_ctr - 0.5 * annotations_wh
        annotations_x_max_y_max = annotations_ctr + 0.5 * annotations_wh
        annotations_bboxes = torch.cat(
            [annotations_x_min_y_min, annotations_x_max_y_max], axis=1)

        overlap_area_top_left = torch.max(pred_bboxes[:, 0:2],
                                          annotations_bboxes[:, 0:2])
        overlap_area_bot_right = torch.min(pred_bboxes[:, 2:4],
                                           annotations_bboxes[:, 2:4])
        overlap_area_sizes = torch.clamp(overlap_area_bot_right -
                                         overlap_area_top_left,
                                         min=0)
        overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:, 1]
        # anchors and annotations convert format to [x1,y1,w,h]
        pred_bboxes_w_h = pred_bboxes[:, 2:4] - pred_bboxes[:, 0:2] + 1
        annotations_w_h = annotations_bboxes[:,
                                             2:4] - annotations_bboxes[:,
                                                                       0:2] + 1
        # compute anchors_area and annotations_area
        pred_bboxes_area = pred_bboxes_w_h[:, 0] * pred_bboxes_w_h[:, 1]
        annotations_area = annotations_w_h[:, 0] * annotations_w_h[:, 1]

        # compute union_area
        union_area = pred_bboxes_area + annotations_area - overlap_area
        union_area = torch.clamp(union_area, min=1e-4)
        # compute ious between one image anchors and one image annotations
        ious = overlap_area / union_area

        pred_bboxes_ctr = (pred_bboxes[:, 2:4] + pred_bboxes[:, 0:2]) / 2
        annotations_bboxes_ctr = (annotations_bboxes[:, 2:4] +
                                  annotations_bboxes[:, 0:2]) / 2
        p2 = (pred_bboxes_ctr[:, 0] - annotations_bboxes_ctr[:, 0])**2 + (
            pred_bboxes_ctr[:, 1] - annotations_bboxes_ctr[:, 1])**2

        enclose_area_top_left = torch.min(pred_bboxes[:, 0:2],
                                          annotations_bboxes[:, 0:2])
        enclose_area_bot_right = torch.max(pred_bboxes[:, 2:4],
                                           annotations_bboxes[:, 2:4])
        enclose_area_sizes = torch.clamp(enclose_area_bot_right -
                                         enclose_area_top_left,
                                         min=1e-2)
        c2 = (enclose_area_sizes[:, 0])**2 + (enclose_area_sizes[:, 1])**2

        dious_loss = 1. - ious + p2 / c2
        dious_loss = dious_loss.sum() / positive_anchor_num
        dious_loss = 2. * dious_loss

        return dious_loss

Note that the final multiplication by 2 is only there to balance the classification loss and the regression loss. To use this loss, simply replace the compute_one_image_smoothl1_loss function in the RetinaLoss class with it (in the for loop of the forward function you additionally need to take each image's anchors).
The paper also mentions that replacing IoU with DIoU in NMS gives better post-processing results.
The DIoU-NMS implementation is as follows:

    def nms(self, one_image_scores, one_image_classes, one_image_pred_bboxes):
        """
        one_image_scores:[anchor_nums],4:classification predict scores
        one_image_classes:[anchor_nums],class indexes for predict scores
        one_image_pred_bboxes:[anchor_nums,4],4:x_min,y_min,x_max,y_max
        """
        # Sort boxes
        sorted_one_image_scores, sorted_one_image_scores_indexes = torch.sort(
            one_image_scores, descending=True)
        sorted_one_image_classes = one_image_classes[
            sorted_one_image_scores_indexes]
        sorted_one_image_pred_bboxes = one_image_pred_bboxes[
            sorted_one_image_scores_indexes]
        sorted_pred_bboxes_w_h = sorted_one_image_pred_bboxes[:,
                                                              2:] - sorted_one_image_pred_bboxes[:, :
                                                                                                 2]

        sorted_pred_bboxes_areas = sorted_pred_bboxes_w_h[:,
                                                          0] * sorted_pred_bboxes_w_h[:,
                                                                                      1]
        detected_classes = torch.unique(sorted_one_image_classes, sorted=True)

        keep_scores, keep_classes, keep_pred_bboxes = [], [], []
        for detected_class in detected_classes:
            single_class_scores = sorted_one_image_scores[
                sorted_one_image_classes == detected_class]
            single_class_pred_bboxes = sorted_one_image_pred_bboxes[
                sorted_one_image_classes == detected_class]
            single_class_pred_bboxes_areas = sorted_pred_bboxes_areas[
                sorted_one_image_classes == detected_class]
            single_class = sorted_one_image_classes[sorted_one_image_classes ==
                                                    detected_class]

            single_keep_scores, single_keep_classes, single_keep_pred_bboxes = [], [], []
            while single_class_scores.numel() > 0:
                top1_score, top1_class, top1_pred_bbox = single_class_scores[
                    0:1], single_class[0:1], single_class_pred_bboxes[0:1]

                single_keep_scores.append(top1_score)
                single_keep_classes.append(top1_class)
                single_keep_pred_bboxes.append(top1_pred_bbox)

                top1_areas = single_class_pred_bboxes_areas[0]

                if single_class_scores.numel() == 1:
                    break

                single_class_scores = single_class_scores[1:]
                single_class = single_class[1:]
                single_class_pred_bboxes = single_class_pred_bboxes[1:]
                single_class_pred_bboxes_areas = single_class_pred_bboxes_areas[
                    1:]

                overlap_area_top_left = torch.max(
                    single_class_pred_bboxes[:, :2], top1_pred_bbox[:, :2])
                overlap_area_bot_right = torch.min(
                    single_class_pred_bboxes[:, 2:], top1_pred_bbox[:, 2:])
                overlap_area_sizes = torch.clamp(overlap_area_bot_right -
                                                 overlap_area_top_left,
                                                 min=0)
                overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:,
                                                                             1]

                # compute union_area
                union_area = top1_areas + single_class_pred_bboxes_areas - overlap_area
                union_area = torch.clamp(union_area, min=1e-4)
                # compute ious for top1 pred_bbox and the other pred_bboxes
                ious = overlap_area / union_area

                top1_pred_bbox_ctr = (top1_pred_bbox[:, 2:4] +
                                      top1_pred_bbox[:, 0:2]) / 2
                single_class_pred_bboxes_ctr = (
                    single_class_pred_bboxes[:, 2:4] +
                    single_class_pred_bboxes[:, 0:2]) / 2
                p2 = (top1_pred_bbox_ctr[:, 0] -
                      single_class_pred_bboxes_ctr[:, 0])**2 + (
                          top1_pred_bbox_ctr[:, 1] -
                          single_class_pred_bboxes_ctr[:, 1])**2

                enclose_area_top_left = torch.min(
                    top1_pred_bbox[:, 0:2], single_class_pred_bboxes[:, 0:2])
                enclose_area_bot_right = torch.max(
                    top1_pred_bbox[:, 2:4], single_class_pred_bboxes[:, 2:4])
                enclose_area_sizes = torch.clamp(enclose_area_bot_right -
                                                 enclose_area_top_left,
                                                 min=1e-4)
                c2 = (enclose_area_sizes[:, 0])**2 + (enclose_area_sizes[:,
                                                                         1])**2

                dious = ious - p2 / c2

                single_class_scores = single_class_scores[
                    dious < self.nms_threshold]
                single_class = single_class[dious < self.nms_threshold]
                single_class_pred_bboxes = single_class_pred_bboxes[
                    dious < self.nms_threshold]
                single_class_pred_bboxes_areas = single_class_pred_bboxes_areas[
                    dious < self.nms_threshold]

            single_keep_scores = torch.cat(single_keep_scores, axis=0)
            single_keep_classes = torch.cat(single_keep_classes, axis=0)
            single_keep_pred_bboxes = torch.cat(single_keep_pred_bboxes,
                                                axis=0)

            keep_scores.append(single_keep_scores)
            keep_classes.append(single_keep_classes)
            keep_pred_bboxes.append(single_keep_pred_bboxes)

        keep_scores = torch.cat(keep_scores, axis=0)
        keep_classes = torch.cat(keep_classes, axis=0)
        keep_pred_bboxes = torch.cat(keep_pred_bboxes, axis=0)

        return keep_scores, keep_classes, keep_pred_bboxes

This NMS function can directly replace the NMS function in the RetinaDecoder class. According to the experimental results reported in the paper, performance also improves after the replacement even with nms_threshold unchanged.
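
DIoU is never larger than IoU for the same pair of boxes, since it subtracts a non-negative normalized center distance, so with the same threshold suppression becomes slightly less aggressive for boxes whose centers are far apart. The small self-contained sketch below (the iou_and_diou helper and the two boxes are illustrative only, not repository code) shows the difference for one pair of boxes:

import torch

def iou_and_diou(box_a, box_b):
    # two single boxes in [x_min, y_min, x_max, y_max] format
    overlap_top_left = torch.max(box_a[:2], box_b[:2])
    overlap_bot_right = torch.min(box_a[2:], box_b[2:])
    overlap_wh = torch.clamp(overlap_bot_right - overlap_top_left, min=0)
    overlap_area = overlap_wh[0] * overlap_wh[1]
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    iou = overlap_area / torch.clamp(area_a + area_b - overlap_area, min=1e-4)
    # squared center distance over squared enclosing-box diagonal
    p2 = (((box_a[:2] + box_a[2:]) / 2 - (box_b[:2] + box_b[2:]) / 2)**2).sum()
    c2 = ((torch.max(box_a[2:], box_b[2:]) -
           torch.min(box_a[:2], box_b[:2]))**2).sum()
    return iou, iou - p2 / c2

box_a = torch.tensor([0., 0., 100., 100.])
box_b = torch.tensor([30., 30., 130., 130.])
print(iou_and_diou(box_a, box_b))  # IoU is about 0.32, DIoU is about 0.27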

CIoU loss

Both DIoU loss and CIoU loss come from this paper: https://arxiv.org/pdf/1911.08287.pdf . CIoU loss adds one more term, αv, on top of DIoU loss (α is a weight coefficient and v measures the consistency of the aspect ratio). This term accounts for the aspect ratio of the predicted box relative to the target box and makes the regression direction more reasonable. Likewise, CIoU can replace IoU as the suppression criterion in NMS post-processing to make the NMS results more reasonable.
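Written as a formula (w, h and w^{gt}, h^{gt} are the width and height of the predicted box and the target box; α and v correspond directly to alpha and v in the code below):

L_{CIoU} = 1 - IoU + \frac{\rho^2(b, b^{gt})}{c^2} + \alpha v
v = \frac{4}{\pi^2}\left(\arctan\frac{w^{gt}}{h^{gt}} - \arctan\frac{w}{h}\right)^2
\alpha = \frac{v}{1 - IoU + v}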
The CIoU loss implementation is as follows:

    def compute_one_image_ciou_loss(self, per_image_reg_heads,
                                    per_image_anchors,
                                    per_image_anchors_annotations):
        """
        compute one image ciou loss(reg loss)
        per_image_reg_heads:[anchor_num,4]
        per_image_anchors:[anchor_num,4]
        per_image_anchors_annotations:[anchor_num,5]
        """
        # keep only positive anchors (gt class index > 0), other anchors don't calculate ciou loss
        device = per_image_reg_heads.device
        per_image_reg_heads = per_image_reg_heads[
            per_image_anchors_annotations[:, 4] > 0]
        per_image_anchors = per_image_anchors[
            per_image_anchors_annotations[:, 4] > 0]
        per_image_anchors_annotations = per_image_anchors_annotations[
            per_image_anchors_annotations[:, 4] > 0]
        positive_anchor_num = per_image_anchors_annotations.shape[0]

        if positive_anchor_num == 0:
            return torch.tensor(0.).to(device)

        per_image_anchors_w_h = per_image_anchors[:,
                                                  2:] - per_image_anchors[:, :2]
        per_image_anchors_ctr = per_image_anchors[:, :
                                                  2] + 0.5 * per_image_anchors_w_h

        factor = torch.tensor([[0.1, 0.1, 0.2, 0.2]]).to(device)
        per_image_reg_heads = per_image_reg_heads * factor

        pred_bboxes_wh = torch.exp(
            per_image_reg_heads[:, 2:]) * per_image_anchors_w_h
        pred_bboxes_ctr = per_image_reg_heads[:, :
                                              2] * per_image_anchors_w_h + per_image_anchors_ctr
        pred_bboxes_x_min_y_min = pred_bboxes_ctr - 0.5 * pred_bboxes_wh
        pred_bboxes_x_max_y_max = pred_bboxes_ctr + 0.5 * pred_bboxes_wh
        pred_bboxes = torch.cat(
            [pred_bboxes_x_min_y_min, pred_bboxes_x_max_y_max], axis=1)

        per_image_annotations = per_image_anchors_annotations[:, 0:4] * factor
        annotations_wh = torch.exp(
            per_image_annotations[:, 2:]) * per_image_anchors_w_h
        annotations_ctr = per_image_annotations[:, :
                                                2] * per_image_anchors_w_h + per_image_anchors_ctr
        annotations_x_min_y_min = annotations_ctr - 0.5 * annotations_wh
        annotations_x_max_y_max = annotations_ctr + 0.5 * annotations_wh
        annotations_bboxes = torch.cat(
            [annotations_x_min_y_min, annotations_x_max_y_max], axis=1)

        overlap_area_top_left = torch.max(pred_bboxes[:, 0:2],
                                          annotations_bboxes[:, 0:2])
        overlap_area_bot_right = torch.min(pred_bboxes[:, 2:4],
                                           annotations_bboxes[:, 2:4])
        overlap_area_sizes = torch.clamp(overlap_area_bot_right -
                                         overlap_area_top_left,
                                         min=0)
        overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:, 1]
        # anchors and annotations convert format to [x1,y1,w,h]
        pred_bboxes_w_h = pred_bboxes[:, 2:4] - pred_bboxes[:, 0:2] + 1
        annotations_w_h = annotations_bboxes[:,
                                             2:4] - annotations_bboxes[:,
                                                                       0:2] + 1
        # compute anchors_area and annotations_area
        pred_bboxes_area = pred_bboxes_w_h[:, 0] * pred_bboxes_w_h[:, 1]
        annotations_area = annotations_w_h[:, 0] * annotations_w_h[:, 1]

        # compute union_area
        union_area = pred_bboxes_area + annotations_area - overlap_area
        union_area = torch.clamp(union_area, min=1e-4)
        # compute ious between one image anchors and one image annotations
        ious = overlap_area / union_area

        pred_bboxes_ctr = (pred_bboxes[:, 2:4] + pred_bboxes[:, 0:2]) / 2
        annotations_bboxes_ctr = (annotations_bboxes[:, 2:4] +
                                  annotations_bboxes[:, 0:2]) / 2
        p2 = (pred_bboxes_ctr[:, 0] - annotations_bboxes_ctr[:, 0])**2 + (
            pred_bboxes_ctr[:, 1] - annotations_bboxes_ctr[:, 1])**2

        enclose_area_top_left = torch.min(pred_bboxes[:, 0:2],
                                          annotations_bboxes[:, 0:2])
        enclose_area_bot_right = torch.max(pred_bboxes[:, 2:4],
                                           annotations_bboxes[:, 2:4])
        enclose_area_sizes = torch.clamp(enclose_area_bot_right -
                                         enclose_area_top_left,
                                         min=1e-2)
        c2 = (enclose_area_sizes[:, 0])**2 + (enclose_area_sizes[:, 1])**2

        # alpha is only a weight coefficient and must not receive gradient,
        # so it is computed under torch.no_grad()
        with torch.no_grad():
            v = torch.pow(
                (torch.atan(annotations_w_h[:, 0] / annotations_w_h[:, 1]) -
                 torch.atan(pred_bboxes_w_h[:, 0] / pred_bboxes_w_h[:, 1])),
                2) * (4 / (math.pi**2))
            alpha = v / (1 - ious + v)

        # recompute v outside no_grad so the aspect-ratio term keeps its gradient
        v = torch.pow(
            (torch.atan(annotations_w_h[:, 0] / annotations_w_h[:, 1]) -
             torch.atan(pred_bboxes_w_h[:, 0] / pred_bboxes_w_h[:, 1])),
            2) * (4 / (math.pi**2))

        cious_loss = 1. - ious + p2 / c2 + alpha * v
        cious_loss = cious_loss.sum() / positive_anchor_num
        cious_loss = 2. * cious_loss

        return cious_loss

Note that alpha, as a weight coefficient, does not receive gradient, and the final multiplication by 2 is only there to balance the classification loss and the regression loss. To use this loss, simply replace the compute_one_image_smoothl1_loss function in the RetinaLoss class with it (in the for loop of the forward function you additionally need to take each image's anchors).
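
An equivalent way to express "alpha carries no gradient" is to compute v once with gradients and detach only the alpha expression. The sketch below is self-contained, with made-up dummy tensors that merely mimic the shapes of the quantities computed inside compute_one_image_ciou_loss; it is an alternative illustration, not the repository code:

import math
import torch

# dummy stand-ins for quantities computed inside compute_one_image_ciou_loss
ious = torch.tensor([0.6, 0.3], requires_grad=True)
p2 = torch.tensor([100.0, 400.0])
c2 = torch.tensor([1000.0, 1000.0])
pred_bboxes_w_h = torch.tensor([[40.0, 20.0], [30.0, 30.0]], requires_grad=True)
annotations_w_h = torch.tensor([[35.0, 25.0], [60.0, 20.0]])

v = torch.pow(
    torch.atan(annotations_w_h[:, 0] / annotations_w_h[:, 1]) -
    torch.atan(pred_bboxes_w_h[:, 0] / pred_bboxes_w_h[:, 1]),
    2) * (4 / (math.pi**2))
# detaching alpha makes it a constant weight, which has the same effect as
# computing it under torch.no_grad(): no gradient flows through alpha
alpha = (v / (1 - ious + v)).detach()
cious_loss = (1. - ious + p2 / c2 + alpha * v).mean()
cious_loss.backward()  # gradients reach ious and pred_bboxes_w_h, but not alpha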
The paper also mentions that replacing IoU with DIoU in NMS gives better post-processing results; I also tried using CIoU in place of IoU in NMS.
The CIoU-NMS implementation is as follows:

    def nms(self, one_image_scores, one_image_classes, one_image_pred_bboxes):
        """
        one_image_scores:[anchor_nums],4:classification predict scores
        one_image_classes:[anchor_nums],class indexes for predict scores
        one_image_pred_bboxes:[anchor_nums,4],4:x_min,y_min,x_max,y_max
        """
        # Sort boxes
        sorted_one_image_scores, sorted_one_image_scores_indexes = torch.sort(
            one_image_scores, descending=True)
        sorted_one_image_classes = one_image_classes[
            sorted_one_image_scores_indexes]
        sorted_one_image_pred_bboxes = one_image_pred_bboxes[
            sorted_one_image_scores_indexes]
        sorted_pred_bboxes_w_h = sorted_one_image_pred_bboxes[:,
                                                              2:] - sorted_one_image_pred_bboxes[:, :
                                                                                                 2]

        sorted_pred_bboxes_areas = sorted_pred_bboxes_w_h[:,
                                                          0] * sorted_pred_bboxes_w_h[:,
                                                                                      1]
        detected_classes = torch.unique(sorted_one_image_classes, sorted=True)

        keep_scores, keep_classes, keep_pred_bboxes = [], [], []
        for detected_class in detected_classes:
            single_class_scores = sorted_one_image_scores[
                sorted_one_image_classes == detected_class]
            single_class_pred_bboxes = sorted_one_image_pred_bboxes[
                sorted_one_image_classes == detected_class]
            single_class_pred_bboxes_areas = sorted_pred_bboxes_areas[
                sorted_one_image_classes == detected_class]
            single_class = sorted_one_image_classes[sorted_one_image_classes ==
                                                    detected_class]

            single_keep_scores, single_keep_classes, single_keep_pred_bboxes = [], [], []
            while single_class_scores.numel() > 0:
                top1_score, top1_class, top1_pred_bbox = single_class_scores[
                    0:1], single_class[0:1], single_class_pred_bboxes[0:1]

                single_keep_scores.append(top1_score)
                single_keep_classes.append(top1_class)
                single_keep_pred_bboxes.append(top1_pred_bbox)

                top1_areas = single_class_pred_bboxes_areas[0]

                if single_class_scores.numel() == 1:
                    break

                single_class_scores = single_class_scores[1:]
                single_class = single_class[1:]
                single_class_pred_bboxes = single_class_pred_bboxes[1:]
                single_class_pred_bboxes_areas = single_class_pred_bboxes_areas[
                    1:]

                overlap_area_top_left = torch.max(
                    single_class_pred_bboxes[:, :2], top1_pred_bbox[:, :2])
                overlap_area_bot_right = torch.min(
                    single_class_pred_bboxes[:, 2:], top1_pred_bbox[:, 2:])
                overlap_area_sizes = torch.clamp(overlap_area_bot_right -
                                                 overlap_area_top_left,
                                                 min=0)
                overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:,
                                                                             1]

                # compute union_area
                union_area = top1_areas + single_class_pred_bboxes_areas - overlap_area
                union_area = torch.clamp(union_area, min=1e-4)
                # compute ious for top1 pred_bbox and the other pred_bboxes
                ious = overlap_area / union_area

                top1_pred_bbox_ctr = (top1_pred_bbox[:, 2:4] +
                                      top1_pred_bbox[:, 0:2]) / 2
                single_class_pred_bboxes_ctr = (
                    single_class_pred_bboxes[:, 2:4] +
                    single_class_pred_bboxes[:, 0:2]) / 2
                p2 = (top1_pred_bbox_ctr[:, 0] -
                      single_class_pred_bboxes_ctr[:, 0])**2 + (
                          top1_pred_bbox_ctr[:, 1] -
                          single_class_pred_bboxes_ctr[:, 1])**2

                enclose_area_top_left = torch.min(
                    top1_pred_bbox[:, 0:2], single_class_pred_bboxes[:, 0:2])
                enclose_area_bot_right = torch.max(
                    top1_pred_bbox[:, 2:4], single_class_pred_bboxes[:, 2:4])
                enclose_area_sizes = torch.clamp(enclose_area_bot_right -
                                                 enclose_area_top_left,
                                                 min=1e-4)
                c2 = (enclose_area_sizes[:, 0])**2 + (enclose_area_sizes[:,
                                                                         1])**2

                top1_pred_bbox_wh = top1_pred_bbox[:,
                                                   2:4] - top1_pred_bbox[:,
                                                                         0:2]
                single_class_pred_bboxes_wh = single_class_pred_bboxes[:, 2:
                                                                       4] - single_class_pred_bboxes[:,
                                                                                                     0:
                                                                                                     2]
                v = torch.pow(
                    (torch.atan(single_class_pred_bboxes_wh[:, 0] /
                                single_class_pred_bboxes_wh[:, 1]) -
                     torch.atan(top1_pred_bbox_wh[:, 0] /
                                top1_pred_bbox_wh[:, 1])), 2) * (4 /
                                                                 (math.pi**2))
                alpha = v / (1 - ious + v)

                cious = ious - p2 / c2 - alpha * v

                single_class_scores = single_class_scores[
                    cious < self.nms_threshold]
                single_class = single_class[cious < self.nms_threshold]
                single_class_pred_bboxes = single_class_pred_bboxes[
                    cious < self.nms_threshold]
                single_class_pred_bboxes_areas = single_class_pred_bboxes_areas[
                    cious < self.nms_threshold]

            single_keep_scores = torch.cat(single_keep_scores, axis=0)
            single_keep_classes = torch.cat(single_keep_classes, axis=0)
            single_keep_pred_bboxes = torch.cat(single_keep_pred_bboxes,
                                                axis=0)

            keep_scores.append(single_keep_scores)
            keep_classes.append(single_keep_classes)
            keep_pred_bboxes.append(single_keep_pred_bboxes)

        keep_scores = torch.cat(keep_scores, axis=0)
        keep_classes = torch.cat(keep_classes, axis=0)
        keep_pred_bboxes = torch.cat(keep_pred_bboxes, axis=0)

        return keep_scores, keep_classes, keep_pred_bboxes

This NMS function can directly replace the NMS function in the RetinaDecoder class. According to the experimental results reported in the paper, performance also improves after the replacement even with nms_threshold unchanged.

Experimental results

The experiments use the final ResNet50-RetinaNet-aug-iscrowd model from Implementing RetinaNet from Scratch (7) as the baseline. The input resolution is 600, roughly equivalent to a resolution of 450 in the RetinaNet paper.

| Network | batch | gpu-num | apex | syncbn | epoch5-mAP-mAR-loss | epoch5-mAP-mAR-loss | epoch12-mAP-mAR-loss |
| --- | --- | --- | --- | --- | --- | --- | --- |
| baseline | 15 | 1 | yes | no | 0.254,0.394,0.62 | 0.280,0.418,0.53 | 0.286,0.421,0.50 |
| baseline-giou | 15 | 1 | yes | no | 0.259,0.400,0.68 | 0.284,0.425,0.59 | 0.292,0.429,0.57 |
| baseline-diou | 15 | 1 | yes | no | 0.257,0.399,0.67 | 0.285,0.422,0.59 | 0.291,0.424,0.57 |
| baseline-diou-nms | 15 | 1 | yes | no | 0.256,0.395,0.67 | 0.286,0.424,0.59 | 0.290,0.428,0.57 |
| baseline-ciou | 15 | 1 | yes | no | 0.263,0.400,0.67 | 0.287,0.420,0.59 | 0.292,0.426,0.57 |
| baseline-ciou-nms | 15 | 1 | yes | no | 0.257,0.401,0.67 | 0.284,0.418,0.59 | 0.292,0.429,0.57 |

baseline-giou means the regression loss is GIoU loss. baseline-diou-nms means the regression loss is DIoU loss and the IoU in NMS is replaced with DIoU. baseline-ciou-nms means the regression loss is CIoU loss and the IoU in NMS is replaced with CIoU. Judging from the results, models trained with GIoU, DIoU or CIoU perform better than the one trained with smooth L1 loss, but replacing the IoU criterion in NMS with DIoU or CIoU does not bring a clear additional gain.
