faster rcnn代碼解讀(八)rcnn_proposal_target_gen

faster rcnn代碼解讀參考

https://github.com/adityaarun1/pytorch_fast-er_rcnn

    https://github.com/jwyang/faster-rcnn.pytorch

之前 RPN 的 anchor 生成、target 生成以及 loss 都已經有了,RPN 環節已經是完整的了,下面就是 RCNN 環節。RCNN 的輸入其實就是 RPN 的輸出。

class rcnn_target_layer(nn.Module):
    """
    Assign object detection proposals to ground-truth targets. Produces proposal
    classification labels and bounding-box regression targets.

    Each image contributes exactly ``int(cfg['batch_size'] / batch)`` sampled
    RoIs. When an image has both foreground (IoU >= ``cfg['fg_thresh']``) and
    background (IoU in ``[cfg['bg_thresh_lo'], cfg['bg_thresh_hi'])``)
    candidates, at most ``round(cfg['fg_fraction'] * rois_per_image)`` (min 1)
    of them are foreground; if only one kind exists, every sampled RoI is of
    that kind (drawn with replacement).
    """

    def __init__(self, nclasses):
        super(rcnn_target_layer, self).__init__()
        # Regression-target normalisation constants and per-coordinate loss
        # weights; they are moved to the data's device when first used.
        self.bbox_normalize_means = torch.FloatTensor(cfg['bbox_normalize_means'])
        self.bbox_normalize_stds = torch.FloatTensor(cfg['bbox_normalize_stds'])
        self.bbox_inside_weights = torch.FloatTensor(cfg['bbox_inside_weights'])
        self.nclasses = nclasses

    def forward(self, rpn_proposal, gt_boxes):
        """Sample RoIs for the RCNN head and build their training targets.

        :param rpn_proposal: [batch, post_nms_topN, 5] proposals
            (batch_idx, x1, y1, x2, y2) coming from the RPN (or any other
            source); per the original notes some rows may have all-zero scores.
        :param gt_boxes:     [batch, gt_num, 5] (x1, y1, x2, y2, label)
        :return: tuple of
            rcnn_rois                 [batch, rois_per_image, 5] (i, x1, y1, x2, y2)
            rcnn_labels               [batch, rois_per_image]  class id, 0 = background
            rcnn_bbox_targets         [batch, rois_per_image, 4 * nclasses]
            rcnn_bbox_inside_weights  [batch, rois_per_image, 4 * nclasses]
            rcnn_bbox_outside_weights [batch, rois_per_image, 4 * nclasses]
        """
        all_rois = rpn_proposal
        if cfg['use_gt']:
            # Include ground-truth boxes in the candidate set, laid out as
            # (0, x1, y1, x2, y2) so they match the proposal format.
            gt_boxes_append = gt_boxes.new_zeros(gt_boxes.size())
            gt_boxes_append[:, :, 1:5] = gt_boxes[:, :, :4]
            gt_boxes_append = gt_boxes_append.to(all_rois.device)
            all_rois = torch.cat([all_rois, gt_boxes_append], 1)

        num_images = gt_boxes.shape[0]
        # RoI budget per image, and the foreground share of it (at least one).
        rois_per_image = int(cfg['batch_size'] / num_images)
        fg_rois_per_image = int(np.round(cfg['fg_fraction'] * rois_per_image))
        fg_rois_per_image = max(fg_rois_per_image, 1)

        return self._sample_rois(all_rois, gt_boxes, fg_rois_per_image,
                                 rois_per_image, self.nclasses)

    def backward(self, top, propagate_down, bottom):
        """This layer does not propagate gradients."""
        pass

    def reshape(self, bottom, top):
        """Reshaping happens during the call to forward."""
        pass

    def _sample_rois(self, all_rois, gt_boxes, fg_rois_per_image,
                     rois_per_image, num_classes):
        """Generate a random sample of RoIs comprising foreground and background examples.

        :param all_rois:          [batch, num_rois, 5] (batch_idx, x1, y1, x2, y2)
        :param gt_boxes:          [batch, gt_num, 5] (x1, y1, x2, y2, label)
        :param fg_rois_per_image: foreground quota for every image (not mutated)
        :param rois_per_image:    number of RoIs sampled per image
        :param num_classes:       size of the class axis; every label must be < num_classes
        :return: (rois, labels, bbox_targets, bbox_inside_weights, bbox_outside_weights)
        """
        def _random_choice(pool, count):
            # Draw `count` entries of the 1-D tensor `pool` uniformly with
            # replacement. numpy is used because upstream reported torch.rand /
            # torch.randperm misbehaving (huge values, multi-GPU segfaults).
            picks = np.random.randint(0, pool.numel(), size=count)
            return pool[torch.from_numpy(picks).to(device=pool.device, dtype=torch.long)]

        gt_boxes = gt_boxes.to(all_rois.device)
        # IoU of every RoI with every gt box: [batch, num_rois, gt_num]
        overlaps = bbox_overlaps_batch(all_rois, gt_boxes)
        max_overlaps, gt_assignment = torch.max(overlaps, 2)

        batch_size = overlaps.size(0)
        # Shift image i's gt indices by i * gt_num so that a single gather on the
        # flattened label column fetches the label of every RoI in the batch.
        offset = torch.arange(0, batch_size) * gt_boxes.size(1)
        offset = offset.view(-1, 1).type_as(gt_assignment) + gt_assignment
        labels = gt_boxes[:, :, 4].contiguous().view(-1)[offset.view(-1)].view(batch_size, -1)

        labels_batch = labels.new_zeros(batch_size, rois_per_image)
        rois_batch = all_rois.new_zeros(batch_size, rois_per_image, 5)
        gt_rois_batch = all_rois.new_zeros(batch_size, rois_per_image, 5)

        for i in range(batch_size):
            fg_inds = (max_overlaps[i] >= cfg['fg_thresh']).nonzero().view(-1)
            bg_inds = ((max_overlaps[i] < cfg['bg_thresh_hi']) &
                       (max_overlaps[i] >= cfg['bg_thresh_lo'])).nonzero().view(-1)
            num_fg, num_bg = fg_inds.numel(), bg_inds.numel()

            # The foreground count drawn for THIS image is kept separately: the
            # quota argument must not shrink for later images, and the label
            # clamp below must use the number actually drawn here.
            if num_fg > 0 and num_bg > 0:
                fg_rois_per_this_image = min(fg_rois_per_image, num_fg)
                perm = torch.from_numpy(np.random.permutation(num_fg)[:fg_rois_per_this_image])
                fg_inds = fg_inds[perm.to(device=fg_inds.device, dtype=torch.long)]
                bg_inds = _random_choice(bg_inds, rois_per_image - fg_rois_per_this_image)
            elif num_fg > 0:
                fg_rois_per_this_image = rois_per_image
                fg_inds = _random_choice(fg_inds, rois_per_image)
            elif num_bg > 0:
                fg_rois_per_this_image = 0
                bg_inds = _random_choice(bg_inds, rois_per_image)
            else:
                raise ValueError("bg_num_rois = 0 and fg_num_rois = 0, this should not happen!")

            # Foreground indices first, background after.
            keep_inds = torch.cat([fg_inds, bg_inds], 0)
            labels_batch[i].copy_(labels[i][keep_inds])
            # Everything after the foreground slots is background (label 0).
            labels_batch[i][fg_rois_per_this_image:] = 0
            rois_batch[i] = all_rois[i][keep_inds]
            rois_batch[i, :, 0] = i
            gt_rois_batch[i] = gt_boxes[i][gt_assignment[i][keep_inds]]

        bbox_target_data = self._compute_targets(rois_batch[:, :, 1:5], gt_rois_batch[:, :, :4], labels_batch)
        bbox_targets, bbox_inside_weights = self._get_bbox_regression_labels(bbox_target_data, num_classes)
        bbox_outside_weights = (bbox_inside_weights > 0).float()
        return rois_batch, labels_batch, bbox_targets, bbox_inside_weights, bbox_outside_weights

    def _compute_targets(self, ex_rois, gt_rois, labels_batch):
        """Compute bounding-box regression targets for a batch of sampled RoIs.

        :param ex_rois:      [batch, n, 4] sampled RoI boxes (x1, y1, x2, y2)
        :param gt_rois:      [batch, n, 4] gt box assigned to each RoI
        :param labels_batch: [batch, n]    class label of each RoI
        :return: [batch, n, 5] (label, tx, ty, tw, th); targets are normalised
            with the configured means/stds when
            ``cfg['bbox_normalize_targets_precomputed']`` is set.
        """
        assert ex_rois.shape[1] == gt_rois.shape[1]
        assert ex_rois.shape[2] == 4
        assert gt_rois.shape[2] == 4

        targets = bbox_transform_batch(ex_rois, gt_rois)
        # Keep the normalisation constants on the same device as the targets.
        self.bbox_normalize_means = self.bbox_normalize_means.to(targets.device)
        self.bbox_normalize_stds = self.bbox_normalize_stds.to(targets.device)
        if cfg['bbox_normalize_targets_precomputed']:
            targets = ((targets - self.bbox_normalize_means.expand_as(targets))
                       / self.bbox_normalize_stds.expand_as(targets))
        return torch.cat([labels_batch.unsqueeze(2), targets], 2)

    def _get_bbox_regression_labels(self, bbox_target_data, num_classes):
        """Expand compact regression targets into the class-specific 4*K layout.

        Bounding-box regression targets (bbox_target_data) are stored in a
        compact form b x N x (class, tx, ty, tw, th). This function expands
        them into the b x N x 4K representation used by the network: only the
        4 columns belonging to the RoI's own class are non-zero, and background
        RoIs (class 0) stay all zero.

        Returns:
            bbox_targets (Tensor): b x N x 4K regression targets
            bbox_inside_weights (Tensor): b x N x 4K loss weights
        """
        batch_size, rois_per_image = bbox_target_data.shape[:2]
        clss = bbox_target_data[:, :, 0]
        bbox_targets = bbox_target_data.new_zeros(batch_size, rois_per_image, 4 * num_classes)
        bbox_inside_weights = bbox_target_data.new_zeros(bbox_targets.size())
        inside_weights = self.bbox_inside_weights.to(bbox_inside_weights.device)
        for b in range(batch_size):
            inds = torch.nonzero(clss[b] > 0).view(-1)
            if inds.numel() == 0:
                continue
            # Column block [4*cls, 4*cls + 4) of each foreground RoI.
            start = 4 * clss[b][inds].long().view(-1, 1)
            dim1_inds = inds.unsqueeze(1).expand(inds.size(0), 4)
            dim2_inds = start + torch.arange(4, device=start.device)
            bbox_targets[b, dim1_inds, dim2_inds] = bbox_target_data[b, inds, 1:]
            bbox_inside_weights[b, dim1_inds, dim2_inds] = inside_weights
        return bbox_targets, bbox_inside_weights

一、代碼解讀

輸入

rpn_proposal: [batch , post_nms_topN, 5] #這是之前 rpn_proposal_gen 的輸出,實際有效的 proposal 個數可能不足 post_nms_topN,不足的部分分數全爲 0,是沒有意義的;另外這裏的 label(-1、0、1)存放在第 0 列(dim=0)

gt_boxes: [batch, gt_num, 5] (x1, y1, x2, y2, label),這裏的 gt label 是真實類別,存放在第 4 列(dim=4)。

  • 先(當 cfg['use_gt'] 爲真時)在 rpn_proposal 後面拼接 gt_boxes,拼接部分的第 0 列(dim=0)爲 0;
  • 然後根據 RCNN 的 batch_size 計算每張圖應採樣的 roi 個數 rois_per_image,以及每張圖中的前景個數 fg_rois_per_image(至少爲 1);
  • 對拼接後的 roi 進行採樣。

二、roi採樣

 def _sample_rois(self, all_rois, gt_boxes, fg_rois_per_image,
                     rois_per_image, num_classes):
        """Generate a random sample of RoIs comprising foreground and background examples.

        :param all_rois:          [batch, num_rois, 5] (batch_idx, x1, y1, x2, y2)
        :param gt_boxes:          [batch, gt_num, 5] (x1, y1, x2, y2, label)
        :param fg_rois_per_image: foreground quota for every image (not mutated)
        :param rois_per_image:    number of RoIs sampled per image
        :param num_classes:       size of the class axis; every label must be < num_classes
        :return: (rois, labels, bbox_targets, bbox_inside_weights, bbox_outside_weights)
        """
        def _random_choice(pool, count):
            # Draw `count` entries of the 1-D tensor `pool` uniformly with
            # replacement. numpy is used because upstream reported torch.rand /
            # torch.randperm misbehaving (huge values, multi-GPU segfaults).
            picks = np.random.randint(0, pool.numel(), size=count)
            return pool[torch.from_numpy(picks).to(device=pool.device, dtype=torch.long)]

        gt_boxes = gt_boxes.to(all_rois.device)
        # IoU of every RoI with every gt box: [batch, num_rois, gt_num]
        overlaps = bbox_overlaps_batch(all_rois, gt_boxes)
        max_overlaps, gt_assignment = torch.max(overlaps, 2)

        batch_size = overlaps.size(0)
        # Shift image i's gt indices by i * gt_num so that a single gather on the
        # flattened label column fetches the label of every RoI in the batch.
        offset = torch.arange(0, batch_size) * gt_boxes.size(1)
        offset = offset.view(-1, 1).type_as(gt_assignment) + gt_assignment
        labels = gt_boxes[:, :, 4].contiguous().view(-1)[offset.view(-1)].view(batch_size, -1)

        labels_batch = labels.new_zeros(batch_size, rois_per_image)
        rois_batch = all_rois.new_zeros(batch_size, rois_per_image, 5)
        gt_rois_batch = all_rois.new_zeros(batch_size, rois_per_image, 5)

        for i in range(batch_size):
            fg_inds = (max_overlaps[i] >= cfg['fg_thresh']).nonzero().view(-1)
            bg_inds = ((max_overlaps[i] < cfg['bg_thresh_hi']) &
                       (max_overlaps[i] >= cfg['bg_thresh_lo'])).nonzero().view(-1)
            num_fg, num_bg = fg_inds.numel(), bg_inds.numel()

            # The foreground count drawn for THIS image is kept separately: the
            # quota argument must not shrink for later images, and the label
            # clamp below must use the number actually drawn here.
            if num_fg > 0 and num_bg > 0:
                fg_rois_per_this_image = min(fg_rois_per_image, num_fg)
                perm = torch.from_numpy(np.random.permutation(num_fg)[:fg_rois_per_this_image])
                fg_inds = fg_inds[perm.to(device=fg_inds.device, dtype=torch.long)]
                bg_inds = _random_choice(bg_inds, rois_per_image - fg_rois_per_this_image)
            elif num_fg > 0:
                fg_rois_per_this_image = rois_per_image
                fg_inds = _random_choice(fg_inds, rois_per_image)
            elif num_bg > 0:
                fg_rois_per_this_image = 0
                bg_inds = _random_choice(bg_inds, rois_per_image)
            else:
                raise ValueError("bg_num_rois = 0 and fg_num_rois = 0, this should not happen!")

            # Foreground indices first, background after.
            keep_inds = torch.cat([fg_inds, bg_inds], 0)
            labels_batch[i].copy_(labels[i][keep_inds])
            # Everything after the foreground slots is background (label 0).
            labels_batch[i][fg_rois_per_this_image:] = 0
            rois_batch[i] = all_rois[i][keep_inds]
            rois_batch[i, :, 0] = i
            gt_rois_batch[i] = gt_boxes[i][gt_assignment[i][keep_inds]]

        bbox_target_data = self._compute_targets(rois_batch[:, :, 1:5], gt_rois_batch[:, :, :4], labels_batch)
        bbox_targets, bbox_inside_weights = self._get_bbox_regression_labels(bbox_target_data, num_classes)
        bbox_outside_weights = (bbox_inside_weights > 0).float()
        return rois_batch, labels_batch, bbox_targets, bbox_inside_weights, bbox_outside_weights

這裏的輸入是

all_rois:   [batch , post_nms_topN + gt_num,5]  (  0, x1, y1, x2,    y2)這裏拼接了 gt_boxes,拼接部分的第 0 列爲 0
gt_boxes:   [batch ,                  gt_num ,5] ( x1, y1, x2, y2, label)
fg_rois_per_image: fg_num#每張圖的fg roi個數
rois_per_image:    rois_num#每張圖的rois總數
self.nclasses:      nclasses#總的類別數

與 rpn_proposal_target 類似,首先計算 overlap,然後根據 fg_thresh、bg_thresh_hi、bg_thresh_lo 劃分前景和背景,再按 fg_num、bg_num 進行採樣。labels_batch 中前景之外的位置都設置爲 0,也就是背景。rois_batch 中放置的是採樣得到的 roi,其第 0 列被改寫爲批次索引 i。

gt_rois_batch 也就是每個 roi 所對應的 gt_boxes。這裏也用到了 offset:把第 i 張圖的 gt 索引加上 i * gt_num,就可以在展平後的 gt label 上一次性取出整個批量中每個 roi 的 label,處理批量數據時非常方便。

三、計算rcnn迴歸的bbox目標

    def _compute_targets(self, ex_rois, gt_rois, labels_batch):
        """Compute bounding-box regression targets for a batch of sampled RoIs.

        :param ex_rois:      [batch, n, 4] sampled RoI boxes (x1, y1, x2, y2)
        :param gt_rois:      [batch, n, 4] gt box assigned to each RoI
        :param labels_batch: [batch, n]    class label of each RoI
        :return: [batch, n, 5] (label, tx, ty, tw, th); targets are normalised
            with the configured means/stds when
            ``cfg['bbox_normalize_targets_precomputed']`` is set.
        """
        assert ex_rois.shape[1] == gt_rois.shape[1]
        assert ex_rois.shape[2] == 4
        assert gt_rois.shape[2] == 4

        targets = bbox_transform_batch(ex_rois, gt_rois)
        # Keep the normalisation constants on the same device as the targets.
        self.bbox_normalize_means = self.bbox_normalize_means.to(targets.device)
        self.bbox_normalize_stds = self.bbox_normalize_stds.to(targets.device)
        if cfg['bbox_normalize_targets_precomputed']:
            targets = ((targets - self.bbox_normalize_means.expand_as(targets))
                       / self.bbox_normalize_stds.expand_as(targets))
        return torch.cat([labels_batch.unsqueeze(2), targets], 2)

也就是根據採樣得到的 roi 與其對應的 gt_box 計算迴歸目標 (tx, ty, tw, th),並且在 cfg['bbox_normalize_targets_precomputed'] 爲真時利用超參 bbox_normalize_means / bbox_normalize_stds 進行歸一化。

返回的張量形狀爲 [batch, rois_per_image, 5],最後一維的第 0 個元素是 label,即 (label, tx, ty, tw, th)。

四、獲取對應的bbox標籤

    def _get_bbox_regression_labels(self, bbox_target_data, num_classes):
        """Expand compact regression targets into the class-specific 4*K layout.

        Bounding-box regression targets (bbox_target_data) are stored in a
        compact form b x N x (class, tx, ty, tw, th). This function expands
        them into the b x N x 4K representation used by the network: only the
        4 columns belonging to the RoI's own class are non-zero, and background
        RoIs (class 0) stay all zero.

        Returns:
            bbox_targets (Tensor): b x N x 4K regression targets
            bbox_inside_weights (Tensor): b x N x 4K loss weights
        """
        batch_size, rois_per_image = bbox_target_data.shape[:2]
        clss = bbox_target_data[:, :, 0]
        bbox_targets = bbox_target_data.new_zeros(batch_size, rois_per_image, 4 * num_classes)
        bbox_inside_weights = bbox_target_data.new_zeros(bbox_targets.size())
        inside_weights = self.bbox_inside_weights.to(bbox_inside_weights.device)
        for b in range(batch_size):
            inds = torch.nonzero(clss[b] > 0).view(-1)
            if inds.numel() == 0:
                continue
            # Column block [4*cls, 4*cls + 4) of each foreground RoI.
            start = 4 * clss[b][inds].long().view(-1, 1)
            dim1_inds = inds.unsqueeze(1).expand(inds.size(0), 4)
            dim2_inds = start + torch.arange(4, device=start.device)
            bbox_targets[b, dim1_inds, dim2_inds] = bbox_target_data[b, inds, 1:]
            bbox_inside_weights[b, dim1_inds, dim2_inds] = inside_weights
        return bbox_targets, bbox_inside_weights

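這裏主要是將緊湊形式的 bbox_target_data (label, tx, ty, tw, th) 展開成網絡使用的 4 * nclasses 形式:每個前景 roi 只有其類別對應的 4 列非零,背景 roi 全爲 0;bbox_inside_weights 也按同樣的位置填充。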

五、輸出

'''
rois_batch          : [batch, rois_per_image, 5]
labels_batch        : [batch, rois_per_image]
bbox_targets        : [batch, rois_per_image, 4 * nclasses]
bbox_inside_weights : [batch, rois_per_image, 4 * nclasses]
bbox_outside_weights: [batch, rois_per_image, 4 * nclasses]
'''
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章