【庖丁解牛】Implementing RetinaNet from Scratch (5): Regression Prediction Conversion, NMS Post-Processing, and Decoding

All code has been uploaded to my GitHub repository: https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training
If you find it useful, please give it a star!
All of the code below has been tested with PyTorch 1.4 and confirmed to work correctly.

Converting Regression Predictions

After the model is trained, its outputs must be decoded before testing. A forward pass through the RetinaNet class produces cls heads and reg heads, but the reg heads predict tx,ty,tw,th; we need the coordinates of the corresponding anchor boxes to convert these into predicted box coordinates. The conversion rule is simply the inverse of the operation in part (4) of this series, where box coordinates were encoded into the regression targets tx,ty,tw,th.
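Written out, with the [0.1, 0.1, 0.2, 0.2] scale factors that appear in the code below (anchor_w, anchor_h and the anchor center are derived from the anchor's corner coordinates), the inverse transform is:

    pred_ctr_x = tx * 0.1 * anchor_w + anchor_ctr_x
    pred_ctr_y = ty * 0.1 * anchor_h + anchor_ctr_y
    pred_w     = exp(tw * 0.2) * anchor_w
    pred_h     = exp(th * 0.2) * anchor_h

and the predicted corners follow as x_min = pred_ctr_x - 0.5 * pred_w, and so on for the other three.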

The conversion from regression predictions to box predictions is implemented as follows:

    def snap_tx_ty_tw_th_reg_heads_to_x1_y1_x2_y2_bboxes(
            self, reg_heads, anchors):
        """
        snap reg heads to pred bboxes
        reg_heads:[anchor_nums,4],4:[tx,ty,tw,th]
        anchors:[anchor_nums,4],4:[x_min,y_min,x_max,y_max]
        """
        anchors_wh = anchors[:, 2:] - anchors[:, :2]
        anchors_ctr = anchors[:, :2] + 0.5 * anchors_wh

        # apply the [0.1,0.1,0.2,0.2] scale factors, inverting the scaling
        # used when the regression targets were encoded
        device = anchors.device
        factor = torch.tensor([[0.1, 0.1, 0.2, 0.2]]).to(device)

        reg_heads = reg_heads * factor

        # inverse transform: exp recovers width/height ratios, the linear
        # terms shift the anchor centers
        pred_bboxes_wh = torch.exp(reg_heads[:, 2:]) * anchors_wh
        pred_bboxes_ctr = reg_heads[:, :2] * anchors_wh + anchors_ctr

        pred_bboxes_x_min_y_min = pred_bboxes_ctr - 0.5 * pred_bboxes_wh
        pred_bboxes_x_max_y_max = pred_bboxes_ctr + 0.5 * pred_bboxes_wh

        pred_bboxes = torch.cat(
            [pred_bboxes_x_min_y_min, pred_bboxes_x_max_y_max], axis=1)
        pred_bboxes = pred_bboxes.int()

        # clip predicted boxes to the image boundary
        pred_bboxes[:, 0] = torch.clamp(pred_bboxes[:, 0], min=0)
        pred_bboxes[:, 1] = torch.clamp(pred_bboxes[:, 1], min=0)
        pred_bboxes[:, 2] = torch.clamp(pred_bboxes[:, 2],
                                        max=self.image_w - 1)
        pred_bboxes[:, 3] = torch.clamp(pred_bboxes[:, 3],
                                        max=self.image_h - 1)

        # pred bboxes shape:[anchor_nums,4]
        return pred_bboxes
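
As a sanity check, here is a minimal round-trip sketch with made-up numbers. It assumes (consistent with the decode above) that part (4) encodes targets by dividing the center offsets and log size ratios by the same [0.1, 0.1, 0.2, 0.2] factors; decoding then recovers the original box exactly:

import torch

# one anchor and one ground-truth box, both [x_min,y_min,x_max,y_max]
anchor = torch.tensor([[20., 20., 120., 220.]])
gt_box = torch.tensor([[30., 40., 130., 240.]])

anchor_wh = anchor[:, 2:] - anchor[:, :2]
anchor_ctr = anchor[:, :2] + 0.5 * anchor_wh
gt_wh = gt_box[:, 2:] - gt_box[:, :2]
gt_ctr = gt_box[:, :2] + 0.5 * gt_wh

factor = torch.tensor([[0.1, 0.1, 0.2, 0.2]])
# encode: tx,ty from center offsets, tw,th from log size ratios
targets = torch.cat([(gt_ctr - anchor_ctr) / anchor_wh,
                     torch.log(gt_wh / anchor_wh)], axis=1) / factor

# decode: the inverse computation used in the method above
reg = targets * factor
pred_wh = torch.exp(reg[:, 2:]) * anchor_wh
pred_ctr = reg[:, :2] * anchor_wh + anchor_ctr
pred_box = torch.cat(
    [pred_ctr - 0.5 * pred_wh, pred_ctr + 0.5 * pred_wh], axis=1)
print(pred_box)  # tensor([[ 30.,  40., 130., 240.]])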

NMS Post-Processing

The standard NMS procedure works as follows. First, sort all candidate detections by classification score in descending order and record which distinct classes appear among them. Then iterate over these detected classes. For each class, extract all candidates of that class (since we already sorted, the extracted candidates remain ordered by score). Move the first candidate into the keep set, compute the IoU between it and every remaining candidate, and discard every candidate whose IoU exceeds the threshold; for RetinaNet this threshold is 0.5. Repeat the same procedure on whatever remains, again moving the first candidate into the keep set, until no candidates are left. That completes NMS for one class; once every class has been processed, NMS is done.
In other detection codebases I have seen many implementations that do not perform NMS per class (i.e., candidates of all classes go through NMS together). I tried this as well and found it consistently scores about 0.2 to 0.5 mAP lower than the standard per-class approach, so the implementation below uses standard per-class NMS.

The NMS post-processing is implemented as follows:

    def nms(self, one_image_scores, one_image_classes, one_image_pred_bboxes):
        """
        one_image_scores:[anchor_nums],classification predict scores
        one_image_classes:[anchor_nums],class indexes for predict scores
        one_image_pred_bboxes:[anchor_nums,4],4:x_min,y_min,x_max,y_max
        """
        # Sort boxes
        sorted_one_image_scores, sorted_one_image_scores_indexes = torch.sort(
            one_image_scores, descending=True)
        sorted_one_image_classes = one_image_classes[
            sorted_one_image_scores_indexes]
        sorted_one_image_pred_bboxes = one_image_pred_bboxes[
            sorted_one_image_scores_indexes]
        sorted_pred_bboxes_w_h = (sorted_one_image_pred_bboxes[:, 2:] -
                                  sorted_one_image_pred_bboxes[:, :2])

        sorted_pred_bboxes_areas = (sorted_pred_bboxes_w_h[:, 0] *
                                    sorted_pred_bboxes_w_h[:, 1])
        detected_classes = torch.unique(sorted_one_image_classes, sorted=True)

        keep_scores, keep_classes, keep_pred_bboxes = [], [], []
        for detected_class in detected_classes:
            single_class_scores = sorted_one_image_scores[
                sorted_one_image_classes == detected_class]
            single_class_pred_bboxes = sorted_one_image_pred_bboxes[
                sorted_one_image_classes == detected_class]
            single_class_pred_bboxes_areas = sorted_pred_bboxes_areas[
                sorted_one_image_classes == detected_class]
            single_class = sorted_one_image_classes[sorted_one_image_classes ==
                                                    detected_class]

            single_keep_scores, single_keep_classes, single_keep_pred_bboxes = [], [], []
            while single_class_scores.numel() > 0:
                top1_score, top1_class, top1_pred_bbox = single_class_scores[
                    0:1], single_class[0:1], single_class_pred_bboxes[0:1]

                single_keep_scores.append(top1_score)
                single_keep_classes.append(top1_class)
                single_keep_pred_bboxes.append(top1_pred_bbox)

                top1_areas = single_class_pred_bboxes_areas[0]

                if single_class_scores.numel() == 1:
                    break

                single_class_scores = single_class_scores[1:]
                single_class = single_class[1:]
                single_class_pred_bboxes = single_class_pred_bboxes[1:]
                single_class_pred_bboxes_areas = single_class_pred_bboxes_areas[
                    1:]

                overlap_area_top_left = torch.max(
                    single_class_pred_bboxes[:, :2], top1_pred_bbox[:, :2])
                overlap_area_bot_right = torch.min(
                    single_class_pred_bboxes[:, 2:], top1_pred_bbox[:, 2:])
                overlap_area_sizes = torch.clamp(overlap_area_bot_right -
                                                 overlap_area_top_left,
                                                 min=0)
                overlap_area = (overlap_area_sizes[:, 0] *
                                overlap_area_sizes[:, 1])

                # compute union_area
                union_area = top1_areas + single_class_pred_bboxes_areas - overlap_area
                union_area = torch.clamp(union_area, min=1e-4)
                # compute ious for top1 pred_bbox and the other pred_bboxes
                ious = overlap_area / union_area

                single_class_scores = single_class_scores[
                    ious < self.nms_threshold]
                single_class = single_class[ious < self.nms_threshold]
                single_class_pred_bboxes = single_class_pred_bboxes[
                    ious < self.nms_threshold]
                single_class_pred_bboxes_areas = single_class_pred_bboxes_areas[
                    ious < self.nms_threshold]

            single_keep_scores = torch.cat(single_keep_scores, axis=0)
            single_keep_classes = torch.cat(single_keep_classes, axis=0)
            single_keep_pred_bboxes = torch.cat(single_keep_pred_bboxes,
                                                axis=0)

            keep_scores.append(single_keep_scores)
            keep_classes.append(single_keep_classes)
            keep_pred_bboxes.append(single_keep_pred_bboxes)

        keep_scores = torch.cat(keep_scores, axis=0)
        keep_classes = torch.cat(keep_classes, axis=0)
        keep_pred_bboxes = torch.cat(keep_pred_bboxes, axis=0)

        return keep_scores, keep_classes, keep_pred_bboxes
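
As a cross-check on the hand-rolled loop above, torchvision provides a vectorized per-class NMS, torchvision.ops.batched_nms, in which boxes of different classes never suppress one another. A minimal sketch with toy numbers:

import torch
from torchvision.ops import batched_nms

# three boxes, two classes; boxes 0 and 1 overlap heavily (IoU ~ 0.68) and
# share class 0, so the lower-scoring box 1 is suppressed; box 2 survives
# because it belongs to a different class
boxes = torch.tensor([[0., 0., 100., 100.],
                      [10., 10., 110., 110.],
                      [5., 5., 105., 105.]])
scores = torch.tensor([0.9, 0.8, 0.7])
classes = torch.tensor([0, 0, 1])

keep = batched_nms(boxes, scores, classes, iou_threshold=0.5)
print(keep)  # tensor([0, 2])

The explicit loop trades speed for transparency; batched_nms is handy for verifying that both produce the same keep set.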

Decoding

With the two pieces above in place, we can now decode. The full decoding pipeline is: first convert the reg head's tx,ty,tw,th predictions into box coordinate predictions (using the anchor coordinates), then filter out candidates whose classification score is too low using a score threshold; for RetinaNet this threshold is 0.05. The remaining candidates then go through NMS post-processing to obtain the kept detections. Finally, we also set a max_detection_num that caps how many detections the final output keeps per image. For the COCO dataset this value is 100, since no single image in COCO is annotated with more than 100 objects.
The decoder is implemented as follows:

import torch
import torch.nn as nn


class RetinaDecoder(nn.Module):
    def __init__(self,
                 image_w,
                 image_h,
                 min_score_threshold=0.05,
                 nms_threshold=0.5,
                 max_detection_num=100):
        super(RetinaDecoder, self).__init__()
        self.image_w = image_w
        self.image_h = image_h
        self.min_score_threshold = min_score_threshold
        self.nms_threshold = nms_threshold
        self.max_detection_num = max_detection_num

    def forward(self, cls_heads, reg_heads, batch_anchors):
        device = cls_heads[0].device
        cls_heads = torch.cat(cls_heads, axis=1)
        reg_heads = torch.cat(reg_heads, axis=1)
        batch_anchors = torch.cat(batch_anchors, axis=1)

        batch_scores, batch_classes, batch_pred_bboxes = [], [], []
        for per_image_cls_heads, per_image_reg_heads, per_image_anchors in zip(
                cls_heads, reg_heads, batch_anchors):
            pred_bboxes = self.snap_tx_ty_tw_th_reg_heads_to_x1_y1_x2_y2_bboxes(
                per_image_reg_heads, per_image_anchors)
            scores, score_classes = torch.max(per_image_cls_heads, dim=1)
            score_classes = score_classes[
                scores > self.min_score_threshold].float()
            pred_bboxes = pred_bboxes[
                scores > self.min_score_threshold].float()
            scores = scores[scores > self.min_score_threshold].float()

            # pad each image's outputs to max_detection_num; unused slots
            # are filled with -1
            single_image_scores = (-1) * torch.ones(
                (self.max_detection_num, ), device=device)
            single_image_classes = (-1) * torch.ones(
                (self.max_detection_num, ), device=device)
            single_image_pred_bboxes = (-1) * torch.ones(
                (self.max_detection_num, 4), device=device)

            if scores.shape[0] != 0:
                scores, score_classes, pred_bboxes = self.nms(
                    scores, score_classes, pred_bboxes)

                sorted_keep_scores, sorted_keep_scores_indexes = torch.sort(
                    scores, descending=True)
                sorted_keep_classes = score_classes[sorted_keep_scores_indexes]
                sorted_keep_pred_bboxes = pred_bboxes[
                    sorted_keep_scores_indexes]

                final_detection_num = min(self.max_detection_num,
                                          sorted_keep_scores.shape[0])

                single_image_scores[
                    0:final_detection_num] = sorted_keep_scores[
                        0:final_detection_num]
                single_image_classes[
                    0:final_detection_num] = sorted_keep_classes[
                        0:final_detection_num]
                single_image_pred_bboxes[
                    0:final_detection_num, :] = sorted_keep_pred_bboxes[
                        0:final_detection_num, :]

            single_image_scores = single_image_scores.unsqueeze(0)
            single_image_classes = single_image_classes.unsqueeze(0)
            single_image_pred_bboxes = single_image_pred_bboxes.unsqueeze(0)

            batch_scores.append(single_image_scores)
            batch_classes.append(single_image_classes)
            batch_pred_bboxes.append(single_image_pred_bboxes)

        batch_scores = torch.cat(batch_scores, axis=0)
        batch_classes = torch.cat(batch_classes, axis=0)
        batch_pred_bboxes = torch.cat(batch_pred_bboxes, axis=0)

        # batch_scores shape:[batch_size,max_detection_num]
        # batch_classes shape:[batch_size,max_detection_num]
        # batch_pred_bboxes shape[batch_size,max_detection_num,4]
        return batch_scores, batch_classes, batch_pred_bboxes

    def nms(self, one_image_scores, one_image_classes, one_image_pred_bboxes):
        """
        one_image_scores:[anchor_nums],classification predict scores
        one_image_classes:[anchor_nums],class indexes for predict scores
        one_image_pred_bboxes:[anchor_nums,4],4:x_min,y_min,x_max,y_max
        """
        # Sort boxes
        sorted_one_image_scores, sorted_one_image_scores_indexes = torch.sort(
            one_image_scores, descending=True)
        sorted_one_image_classes = one_image_classes[
            sorted_one_image_scores_indexes]
        sorted_one_image_pred_bboxes = one_image_pred_bboxes[
            sorted_one_image_scores_indexes]
        sorted_pred_bboxes_w_h = (sorted_one_image_pred_bboxes[:, 2:] -
                                  sorted_one_image_pred_bboxes[:, :2])

        sorted_pred_bboxes_areas = (sorted_pred_bboxes_w_h[:, 0] *
                                    sorted_pred_bboxes_w_h[:, 1])
        detected_classes = torch.unique(sorted_one_image_classes, sorted=True)

        keep_scores, keep_classes, keep_pred_bboxes = [], [], []
        for detected_class in detected_classes:
            single_class_scores = sorted_one_image_scores[
                sorted_one_image_classes == detected_class]
            single_class_pred_bboxes = sorted_one_image_pred_bboxes[
                sorted_one_image_classes == detected_class]
            single_class_pred_bboxes_areas = sorted_pred_bboxes_areas[
                sorted_one_image_classes == detected_class]
            single_class = sorted_one_image_classes[sorted_one_image_classes ==
                                                    detected_class]

            single_keep_scores, single_keep_classes, single_keep_pred_bboxes = [], [], []
            while single_class_scores.numel() > 0:
                top1_score, top1_class, top1_pred_bbox = single_class_scores[
                    0:1], single_class[0:1], single_class_pred_bboxes[0:1]

                single_keep_scores.append(top1_score)
                single_keep_classes.append(top1_class)
                single_keep_pred_bboxes.append(top1_pred_bbox)

                top1_areas = single_class_pred_bboxes_areas[0]

                if single_class_scores.numel() == 1:
                    break

                single_class_scores = single_class_scores[1:]
                single_class = single_class[1:]
                single_class_pred_bboxes = single_class_pred_bboxes[1:]
                single_class_pred_bboxes_areas = single_class_pred_bboxes_areas[
                    1:]

                overlap_area_top_left = torch.max(
                    single_class_pred_bboxes[:, :2], top1_pred_bbox[:, :2])
                overlap_area_bot_right = torch.min(
                    single_class_pred_bboxes[:, 2:], top1_pred_bbox[:, 2:])
                overlap_area_sizes = torch.clamp(overlap_area_bot_right -
                                                 overlap_area_top_left,
                                                 min=0)
                overlap_area = (overlap_area_sizes[:, 0] *
                                overlap_area_sizes[:, 1])

                # compute union_area
                union_area = top1_areas + single_class_pred_bboxes_areas - overlap_area
                union_area = torch.clamp(union_area, min=1e-4)
                # compute ious for top1 pred_bbox and the other pred_bboxes
                ious = overlap_area / union_area

                single_class_scores = single_class_scores[
                    ious < self.nms_threshold]
                single_class = single_class[ious < self.nms_threshold]
                single_class_pred_bboxes = single_class_pred_bboxes[
                    ious < self.nms_threshold]
                single_class_pred_bboxes_areas = single_class_pred_bboxes_areas[
                    ious < self.nms_threshold]

            single_keep_scores = torch.cat(single_keep_scores, axis=0)
            single_keep_classes = torch.cat(single_keep_classes, axis=0)
            single_keep_pred_bboxes = torch.cat(single_keep_pred_bboxes,
                                                axis=0)

            keep_scores.append(single_keep_scores)
            keep_classes.append(single_keep_classes)
            keep_pred_bboxes.append(single_keep_pred_bboxes)

        keep_scores = torch.cat(keep_scores, axis=0)
        keep_classes = torch.cat(keep_classes, axis=0)
        keep_pred_bboxes = torch.cat(keep_pred_bboxes, axis=0)

        return keep_scores, keep_classes, keep_pred_bboxes

    def snap_tx_ty_tw_th_reg_heads_to_x1_y1_x2_y2_bboxes(
            self, reg_heads, anchors):
        """
        snap reg heads to pred bboxes
        reg_heads:[anchor_nums,4],4:[tx,ty,tw,th]
        anchors:[anchor_nums,4],4:[x_min,y_min,x_max,y_max]
        """
        anchors_wh = anchors[:, 2:] - anchors[:, :2]
        anchors_ctr = anchors[:, :2] + 0.5 * anchors_wh

        # apply the [0.1,0.1,0.2,0.2] scale factors, inverting the scaling
        # used when the regression targets were encoded
        device = anchors.device
        factor = torch.tensor([[0.1, 0.1, 0.2, 0.2]]).to(device)

        reg_heads = reg_heads * factor

        # inverse transform: exp recovers width/height ratios, the linear
        # terms shift the anchor centers
        pred_bboxes_wh = torch.exp(reg_heads[:, 2:]) * anchors_wh
        pred_bboxes_ctr = reg_heads[:, :2] * anchors_wh + anchors_ctr

        pred_bboxes_x_min_y_min = pred_bboxes_ctr - 0.5 * pred_bboxes_wh
        pred_bboxes_x_max_y_max = pred_bboxes_ctr + 0.5 * pred_bboxes_wh

        pred_bboxes = torch.cat(
            [pred_bboxes_x_min_y_min, pred_bboxes_x_max_y_max], axis=1)
        pred_bboxes = pred_bboxes.int()

        # clip predicted boxes to the image boundary
        pred_bboxes[:, 0] = torch.clamp(pred_bboxes[:, 0], min=0)
        pred_bboxes[:, 1] = torch.clamp(pred_bboxes[:, 1], min=0)
        pred_bboxes[:, 2] = torch.clamp(pred_bboxes[:, 2],
                                        max=self.image_w - 1)
        pred_bboxes[:, 3] = torch.clamp(pred_bboxes[:, 3],
                                        max=self.image_h - 1)

        # pred bboxes shape:[anchor_nums,4]
        return pred_bboxes
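
Finally, a quick smoke test. This is a minimal runnable sketch with synthetic tensors standing in for the model outputs (the shapes are illustrative assumptions: one FPN level, batch size 2, 10 anchors, 80 classes; a real RetinaNet forward pass returns one tensor per FPN level in each of the three lists):

import torch

cls_heads = [torch.rand(2, 10, 80)]
reg_heads = [torch.randn(2, 10, 4)]

# ten 100x100 anchors spaced along the diagonal, repeated for both images
base = 30.0 * torch.arange(10, dtype=torch.float32).view(10, 1)
anchors = torch.cat([base, base, base + 100.0, base + 100.0], dim=1)
batch_anchors = [anchors.unsqueeze(0).repeat(2, 1, 1)]

decoder = RetinaDecoder(image_w=600, image_h=600)
with torch.no_grad():
    batch_scores, batch_classes, batch_pred_bboxes = decoder(
        cls_heads, reg_heads, batch_anchors)

print(batch_scores.shape, batch_classes.shape, batch_pred_bboxes.shape)
# torch.Size([2, 100]) torch.Size([2, 100]) torch.Size([2, 100, 4])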

With that, the decode stage is complete.
