Faster R-CNN 代碼解讀，參考：
https://github.com/adityaarun1/pytorch_fast-er_rcnn
https://github.com/jwyang/faster-rcnn.pytorch
前面已經完成了 RPN 的 anchor 生成、target 計算以及 loss，RPN 環節已經完整了，下面進入 RCNN 環節。RCNN 的輸入其實就是 RPN 的輸出（proposal）。
class rcnn_target_layer(nn.Module):
    """
    Assign object detection proposals to ground-truth targets. Produces proposal
    classification labels and bounding-box regression targets.

    Configuration is read from the module-level ``cfg`` mapping (keys used:
    bbox_normalize_means, bbox_normalize_stds, bbox_inside_weights, use_gt,
    batch_size, fg_fraction, fg_thresh, bg_thresh_hi, bg_thresh_lo,
    bbox_normalize_targets_precomputed).
    """

    def __init__(self, nclasses):
        super(rcnn_target_layer, self).__init__()
        # Kept as plain CPU tensors; moved to the data's device when used.
        self.bbox_normalize_means = torch.FloatTensor(cfg['bbox_normalize_means'])
        self.bbox_normalize_stds = torch.FloatTensor(cfg['bbox_normalize_stds'])
        self.bbox_inside_weights = torch.FloatTensor(cfg['bbox_inside_weights'])
        self.nclasses = nclasses

    def forward(self, rpn_proposal, gt_boxes):
        """Sample RoIs for the RCNN head and build their training targets.

        :param rpn_proposal: [batch, post_nms_topN, 5] rois ``(0, x1, y1, x2, y2)``
            coming from the RPN (or any other source); may contain all-zero rows.
        :param gt_boxes: [batch, gt_num, 5] ``(x1, y1, x2, y2, label)``
        :return: ``(rois, labels, bbox_targets, bbox_inside_weights,
            bbox_outside_weights)`` with shapes [batch, R, 5], [batch, R],
            [batch, R, 4*nclasses], [batch, R, 4*nclasses], [batch, R, 4*nclasses]
            where ``R = int(cfg['batch_size'] / batch)``.
        """
        all_rois = rpn_proposal
        if cfg['use_gt']:
            # Include ground-truth boxes in the set of candidate rois, laid out
            # like proposals: (0, x1, y1, x2, y2).
            gt_boxes_append = gt_boxes.new_zeros(gt_boxes.size())
            gt_boxes_append[:, :, 1:5] = gt_boxes[:, :, :4]
            gt_boxes_append = gt_boxes_append.to(all_rois.device)
            all_rois = torch.cat([all_rois, gt_boxes_append], 1)

        num_images = gt_boxes.shape[0]
        # The RCNN mini-batch is split evenly across the images of the batch.
        rois_per_image = int(cfg['batch_size'] / num_images)
        fg_rois_per_image = int(np.round(cfg['fg_fraction'] * rois_per_image))
        # Always aim for at least one foreground roi per image.
        fg_rois_per_image = fg_rois_per_image or 1

        return self._sample_rois(all_rois, gt_boxes, fg_rois_per_image,
                                 rois_per_image, self.nclasses)

    def backward(self, top, propagate_down, bottom):
        """This layer does not propagate gradients."""
        pass

    def reshape(self, bottom, top):
        """Reshaping happens during the call to forward."""
        pass

    @staticmethod
    def _sample_fg_bg_inds(fg_inds, bg_inds, fg_rois_per_image, rois_per_image):
        """Choose exactly ``rois_per_image`` roi indices for one image.

        :param fg_inds: 1-D index tensor of foreground candidates
        :param bg_inds: 1-D index tensor of background candidates
        :param fg_rois_per_image: desired (maximum) number of foreground rois
        :param rois_per_image: total number of indices to return
        :return: ``(fg_inds, bg_inds, fg_rois_this_image)`` where
            ``fg_inds.numel() == fg_rois_this_image`` and
            ``fg_inds.numel() + bg_inds.numel() == rois_per_image``
        :raises ValueError: if both candidate sets are empty
        """
        num_fg = fg_inds.numel()
        num_bg = bg_inds.numel()

        def _draw_with_replacement(inds, count):
            # Uniform sampling with replacement through numpy; the upstream code
            # reported torch.rand generating out-of-range values here.
            pick = np.floor(np.random.rand(count) * inds.numel())
            return inds[torch.from_numpy(pick).long().to(inds.device)]

        if num_fg > 0 and num_bg > 0:
            # Per-image count: never overwrite the caller's quota, otherwise the
            # next image in the batch would inherit this image's reduced value.
            fg_rois_this_image = min(fg_rois_per_image, num_fg)
            # numpy permutation instead of torch.randperm (reported segfault on
            # multi-GPU setups); foreground is sampled without replacement.
            perm = torch.from_numpy(np.random.permutation(num_fg)).long().to(fg_inds.device)
            fg_inds = fg_inds[perm[:fg_rois_this_image]]
            bg_inds = _draw_with_replacement(bg_inds, rois_per_image - fg_rois_this_image)
        elif num_fg > 0:
            # No background candidates: fill the whole quota with foreground.
            fg_inds = _draw_with_replacement(fg_inds, rois_per_image)
            fg_rois_this_image = rois_per_image
        elif num_bg > 0:
            # No foreground candidates: fill the whole quota with background.
            bg_inds = _draw_with_replacement(bg_inds, rois_per_image)
            fg_rois_this_image = 0
        else:
            raise ValueError("bg_num_rois = 0 and fg_num_rois = 0, this should not happen!")
        return fg_inds, bg_inds, fg_rois_this_image

    def _sample_rois(self, all_rois, gt_boxes, fg_rois_per_image,
                     rois_per_image, num_classes):
        """Generate a random sample of RoIs comprising foreground and background examples.

        :param all_rois: [batch, num_rois, 5] ``(0, x1, y1, x2, y2)``
        :param gt_boxes: [batch, gt_num, 5] ``(x1, y1, x2, y2, label)``
        :param fg_rois_per_image: desired number of foreground rois per image
        :param rois_per_image: total number of rois sampled per image
        :param num_classes: number of classes; regression tensors get
            ``4 * num_classes`` columns
        :return: rois [batch, rois_per_image, 5] (column 0 = image index),
            labels [batch, rois_per_image], and bbox_targets /
            bbox_inside_weights / bbox_outside_weights, each
            [batch, rois_per_image, 4 * num_classes]
        """
        gt_boxes = gt_boxes.to(all_rois.device)
        # [batch, num_rois, gt_num]
        overlaps = bbox_overlaps_batch(all_rois, gt_boxes)
        # IoU with, and index of, the best-matching gt box for every roi.
        max_overlaps, gt_assignment = torch.max(overlaps, 2)
        batch_size = overlaps.size(0)

        # Turn per-image gt indices into indices over the flattened batch so the
        # label of each roi's assigned gt can be gathered in one shot.
        offset = torch.arange(batch_size, device=gt_assignment.device) * gt_boxes.size(1)
        offset = offset.view(-1, 1) + gt_assignment
        # [batch, num_rois]
        labels = gt_boxes[:, :, 4].contiguous().view(-1)[offset.view(-1)].view(batch_size, -1)

        labels_batch = labels.new_zeros(batch_size, rois_per_image)
        rois_batch = all_rois.new_zeros(batch_size, rois_per_image, 5)
        gt_rois_batch = all_rois.new_zeros(batch_size, rois_per_image, 5)

        for i in range(batch_size):
            # Foreground: IoU >= fg_thresh. Background: IoU in [bg_thresh_lo, bg_thresh_hi).
            fg_inds = (max_overlaps[i] >= cfg['fg_thresh']).nonzero().view(-1)
            bg_inds = ((max_overlaps[i] < cfg['bg_thresh_hi']) &
                       (max_overlaps[i] >= cfg['bg_thresh_lo'])).nonzero().view(-1)
            fg_inds, bg_inds, fg_rois_this_image = self._sample_fg_bg_inds(
                fg_inds, bg_inds, fg_rois_per_image, rois_per_image)

            # Foreground indices come first, then background.
            keep_inds = torch.cat([fg_inds, bg_inds], 0)
            labels_batch[i].copy_(labels[i][keep_inds])
            # Background rois inherit the label of their nearest gt; clamp to 0.
            if fg_rois_this_image < rois_per_image:
                labels_batch[i][fg_rois_this_image:] = 0
            rois_batch[i] = all_rois[i][keep_inds]
            # Column 0 becomes the image index within the batch.
            rois_batch[i, :, 0] = i
            gt_rois_batch[i] = gt_boxes[i][gt_assignment[i][keep_inds]]

        bbox_target_data = self._compute_targets(
            rois_batch[:, :, 1:5], gt_rois_batch[:, :, :4], labels_batch)
        bbox_targets, bbox_inside_weights = self._get_bbox_regression_labels(
            bbox_target_data, num_classes)
        bbox_outside_weights = (bbox_inside_weights > 0).float()
        return rois_batch, labels_batch, bbox_targets, bbox_inside_weights, bbox_outside_weights

    def _compute_targets(self, ex_rois, gt_rois, labels_batch):
        """Compute bounding-box regression targets for a batch of sampled rois.

        :param ex_rois: [batch, R, 4] sampled rois (x1, y1, x2, y2)
        :param gt_rois: [batch, R, 4] gt box assigned to each roi
        :param labels_batch: [batch, R] class label of each roi
        :return: [batch, R, 5] laid out as (label, tx, ty, tw, th)
        """
        assert ex_rois.shape[1] == gt_rois.shape[1]
        assert ex_rois.shape[2] == 4
        assert gt_rois.shape[2] == 4
        targets = bbox_transform_batch(ex_rois, gt_rois)
        # Keep the cached statistics on the same device as the targets.
        self.bbox_normalize_means = self.bbox_normalize_means.to(targets.device)
        self.bbox_normalize_stds = self.bbox_normalize_stds.to(targets.device)
        if cfg['bbox_normalize_targets_precomputed']:
            # Optionally normalize targets by a precomputed mean and stdev.
            targets = ((targets - self.bbox_normalize_means.expand_as(targets))
                       / self.bbox_normalize_stds.expand_as(targets))
        return torch.cat([labels_batch.unsqueeze(2), targets], 2)

    def _get_bbox_regression_labels(self, bbox_target_data, num_classes):
        """Expand compact regression targets into the per-class layout.

        ``bbox_target_data`` is stored in compact form b x N x (class, tx, ty, tw, th).
        It is expanded into the 4-of-4*K representation used by the network:
        only the 4 columns of the roi's own class are non-zero, and rois whose
        class is <= 0 stay all-zero.

        Returns:
            bbox_targets (Tensor): b x N x 4K regression targets
            bbox_inside_weights (Tensor): b x N x 4K loss weights
        """
        batch_size, rois_per_image = bbox_target_data.shape[:2]
        clss = bbox_target_data[:, :, 0]
        bbox_targets = bbox_target_data.new_zeros(batch_size, rois_per_image, 4 * num_classes)
        bbox_inside_weights = torch.zeros_like(bbox_targets)
        # Match dtype as well as device: index assignment requires equal dtypes.
        inside_weights = self.bbox_inside_weights.to(bbox_inside_weights)
        for b in range(batch_size):
            inds = torch.nonzero(clss[b] > 0).view(-1)
            if inds.numel() == 0:
                continue
            cls = clss[b, inds].long().view(-1, 1)
            # [k, 4] index grids: row = roi, columns 4*cls .. 4*cls + 3.
            rows = inds.unsqueeze(1).expand(-1, 4)
            cols = 4 * cls + torch.arange(4, device=cls.device)
            bbox_targets[b, rows, cols] = bbox_target_data[b, inds, 1:]
            bbox_inside_weights[b, rows, cols] = inside_weights
        return bbox_targets, bbox_inside_weights
一、代碼解讀
輸入
rpn_proposal: [batch, post_nms_topN, 5]  # 這是之前 rpn_proposal_gen 的輸出，有效 proposal 的個數可能不足 post_nms_topN，因此其中有一些是沒有意義的（分數全 0 的部分）；另外這裏最後一維的第 0 列存放的是 label（-1, 0, 1）
gt_boxes: [batch, gt_num, 5] (x1, y1, x2, y2, label)，這裏的 label 是真實類別，位於最後一維的第 4 列
- 先將 gt_boxes 拼接到 rpn_proposal 後面（僅當 cfg['use_gt'] 爲真時），拼接部分最後一維的第 0 列置爲 0，其餘 4 列是 gt 座標 (x1, y1, x2, y2)。
- 然後根據 RCNN 的 batch_size 計算每張圖應採樣的 roi 個數 rois_per_image，以及其中前景的個數 fg_rois_per_image（至少爲 1）。
- 對拼接後的 roi 進行採樣。
二、roi採樣
def _sample_rois(self, all_rois, gt_boxes, fg_rois_per_image,
                 rois_per_image, num_classes):
    """Generate a random sample of RoIs comprising foreground and background examples.

    :param all_rois: [batch, post_nms_topN + gt_num, 5] ``(0, x1, y1, x2, y2)``
    :param gt_boxes: [batch, gt_num, 5] ``(x1, y1, x2, y2, label)``
    :param fg_rois_per_image: desired number of foreground rois per image
    :param rois_per_image: total number of rois sampled per image
    :param num_classes: number of classes; regression tensors get
        ``4 * num_classes`` columns
    :return: rois [batch, rois_per_image, 5] (column 0 = image index),
        labels [batch, rois_per_image], and bbox_targets /
        bbox_inside_weights / bbox_outside_weights, each
        [batch, rois_per_image, 4 * num_classes]
    """
    gt_boxes = gt_boxes.to(all_rois.device)
    # [batch, post_nms_topN + gt_num, gt_num]
    overlaps = bbox_overlaps_batch(all_rois, gt_boxes)
    # IoU with, and index of, the best-matching gt box for every roi.
    max_overlaps, gt_assignment = torch.max(overlaps, 2)
    batch_size = overlaps.size(0)

    # Turn per-image gt indices into indices over the flattened batch so the
    # label of each roi's assigned gt can be gathered in one shot.
    offset = torch.arange(batch_size, device=gt_assignment.device) * gt_boxes.size(1)
    offset = offset.view(-1, 1) + gt_assignment
    # [batch, post_nms_topN + gt_num]
    labels = gt_boxes[:, :, 4].contiguous().view(-1)[offset.view(-1)].view(batch_size, -1)

    labels_batch = labels.new_zeros(batch_size, rois_per_image)
    rois_batch = all_rois.new_zeros(batch_size, rois_per_image, 5)
    gt_rois_batch = all_rois.new_zeros(batch_size, rois_per_image, 5)

    def _draw_with_replacement(inds, count):
        # Uniform sampling with replacement through numpy; the upstream code
        # reported torch.rand generating out-of-range values here.
        pick = np.floor(np.random.rand(count) * inds.numel())
        return inds[torch.from_numpy(pick).long().to(inds.device)]

    for i in range(batch_size):
        # Foreground: IoU >= fg_thresh. Background: IoU in [bg_thresh_lo, bg_thresh_hi).
        fg_inds = (max_overlaps[i] >= cfg['fg_thresh']).nonzero().view(-1)
        bg_inds = ((max_overlaps[i] < cfg['bg_thresh_hi']) &
                   (max_overlaps[i] >= cfg['bg_thresh_lo'])).nonzero().view(-1)
        num_fg, num_bg = fg_inds.numel(), bg_inds.numel()

        if num_fg > 0 and num_bg > 0:
            # Per-image count: never overwrite fg_rois_per_image, otherwise the
            # next image in the batch would inherit this image's reduced value.
            fg_rois_per_this_image = min(fg_rois_per_image, num_fg)
            # numpy permutation instead of torch.randperm (reported segfault on
            # multi-GPU setups); foreground is sampled without replacement.
            perm = torch.from_numpy(np.random.permutation(num_fg)).long().to(fg_inds.device)
            fg_inds = fg_inds[perm[:fg_rois_per_this_image]]
            bg_inds = _draw_with_replacement(bg_inds, rois_per_image - fg_rois_per_this_image)
        elif num_fg > 0:
            # No background candidates: fill the whole quota with foreground.
            fg_inds = _draw_with_replacement(fg_inds, rois_per_image)
            fg_rois_per_this_image = rois_per_image
        elif num_bg > 0:
            # No foreground candidates: fill the whole quota with background.
            bg_inds = _draw_with_replacement(bg_inds, rois_per_image)
            fg_rois_per_this_image = 0
        else:
            raise ValueError("bg_num_rois = 0 and fg_num_rois = 0, this should not happen!")

        # Foreground indices come first, then background.
        keep_inds = torch.cat([fg_inds, bg_inds], 0)
        labels_batch[i].copy_(labels[i][keep_inds])
        # Background rois inherit the label of their nearest gt; clamp to 0.
        if fg_rois_per_this_image < rois_per_image:
            labels_batch[i][fg_rois_per_this_image:] = 0
        rois_batch[i] = all_rois[i][keep_inds]
        # Column 0 becomes the image index within the batch.
        rois_batch[i, :, 0] = i
        gt_rois_batch[i] = gt_boxes[i][gt_assignment[i][keep_inds]]

    bbox_target_data = self._compute_targets(
        rois_batch[:, :, 1:5], gt_rois_batch[:, :, :4], labels_batch)
    bbox_targets, bbox_inside_weights = self._get_bbox_regression_labels(
        bbox_target_data, num_classes)
    bbox_outside_weights = (bbox_inside_weights > 0).float()
    return rois_batch, labels_batch, bbox_targets, bbox_inside_weights, bbox_outside_weights
這裏的輸入是：
- all_rois: [batch, post_nms_topN + gt_num, 5] (0, x1, y1, x2, y2)，這裏拼接了 gt_boxes，拼接部分的第 0 列爲 0
- gt_boxes: [batch, gt_num, 5] (x1, y1, x2, y2, label)
- fg_rois_per_image: 每張圖的前景 roi 個數
- rois_per_image: 每張圖的 roi 總數
- num_classes（即 self.nclasses）: 總的類別數
與 rpn_proposals_target 類似，首先計算 overlap，然後根據 fg_thresh、bg_thresh_hi、bg_thresh_lo 劃分前景和背景，再按 fg_num、bg_num 進行採樣。在 labels_batch 中把前景之外的都設置爲 0，也就是背景。rois_batch 中放置的是採樣出的 roi（rpn_proposal 以及拼接的 gt），但其第 0 列被改寫爲該圖在批量中的索引 i。
gt_rois_batch 中是每個採樣 roi 所對應（IoU 最大）的 gt_boxes。這裏還用到了 offset：把每張圖內的 gt 索引加上 i * gt_num，變成展平後整個批量中的全局索引，這樣可以一次性取出所有 roi 對應的 gt label，而不用逐張圖處理。
三、計算rcnn迴歸的bbox目標
def _compute_targets(self, ex_rois, gt_rois, labels_batch):
    """Compute bounding-box regression targets for a batch of sampled rois.

    :param ex_rois: [batch, R, 4] sampled rois (x1, y1, x2, y2)
    :param gt_rois: [batch, R, 4] gt box assigned to each roi
    :param labels_batch: [batch, R] class label of each roi
    :return: [batch, R, 5] laid out as (label, tx, ty, tw, th)
    """
    assert ex_rois.shape[1] == gt_rois.shape[1]
    assert ex_rois.shape[2] == 4
    assert gt_rois.shape[2] == 4
    targets = bbox_transform_batch(ex_rois, gt_rois)
    # Keep the cached statistics on the same device as the targets.
    self.bbox_normalize_means = self.bbox_normalize_means.to(targets.device)
    self.bbox_normalize_stds = self.bbox_normalize_stds.to(targets.device)
    if cfg['bbox_normalize_targets_precomputed']:
        # Optionally normalize targets by a precomputed mean and stdev.
        targets = ((targets - self.bbox_normalize_means.expand_as(targets))
                   / self.bbox_normalize_stds.expand_as(targets))
    return torch.cat([labels_batch.unsqueeze(2), targets], 2)
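也就是根據每個採樣 roi 與其對應的 gt 框，用 bbox_transform_batch 計算迴歸目標（偏移量 tx, ty, tw, th）。當 cfg['bbox_normalize_targets_precomputed'] 爲真時，再用超參 bbox_normalize_means / bbox_normalize_stds 進行歸一化。
返回的 target 形狀爲 [batch, rois_per_image, 5]：最後一維的第 0 列是類別 label，後 4 列是迴歸目標。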
四、獲取對應的bbox標籤
def _get_bbox_regression_labels(self, bbox_target_data, num_classes):
"""Bounding-box regression targets (bbox_target_data) are stored in a
compact form b x N x (class, tx, ty, tw, th)
This function expands those targets into the 4-of-4*K representation used
by the network (i.e. only one class has non-zero targets).
Returns:
bbox_target (ndarray): b x N x 4K blob of regression targets
bbox_inside_weights (ndarray): b x N x 4K blob of loss weights
"""
# Inputs are tensor
batch_size = bbox_target_data.size(0)
rois_per_image = bbox_target_data.size(1)
clss = bbox_target_data[:,:,0]
bbox_targets = bbox_target_data.new_zeros(batch_size,rois_per_image, 4*num_classes )
bbox_inside_weights = bbox_target_data.new(bbox_targets.size()).zero_()
for b in range(batch_size):
if clss[b].sum() == 0: continue
inds = torch.nonzero(clss[b] > 0).view(-1)
if inds.numel()>0:
temp_clss = clss[b][inds].contiguous().view(-1, 1)
dim1_inds = inds.unsqueeze(1).expand(inds.size(0), 4)
dim2_inds = torch.cat([4 * temp_clss, 4 * temp_clss + 1, 4 * temp_clss + 2, 4 * temp_clss + 3], 1).long()
bbox_targets[b, dim1_inds, dim2_inds] = bbox_target_data[b, inds, 1:]
bbox_inside_weights[b, dim1_inds, dim2_inds] = self.bbox_inside_weights.to(bbox_inside_weights.device)
return bbox_targets, bbox_inside_weights
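這裏把緊湊形式 [b, N, (class, tx, ty, tw, th)] 的迴歸目標展開爲網絡使用的 [b, N, 4K] 形式（K 爲類別數）。每個前景 roi 只在其類別對應的 4 列（4*class ~ 4*class+3）填入迴歸目標和 bbox_inside_weights；其它位置，以及背景 roi（class 爲 0）的所有位置，全爲 0。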
五、輸出
'''
rois_batch          : [batch, rois_per_image, 5]
labels_batch        : [batch, rois_per_image]
bbox_targets        : [batch, rois_per_image, 4 * nclasses]
bbox_inside_weights : [batch, rois_per_image, 4 * nclasses]
bbox_outside_weights: [batch, rois_per_image, 4 * nclasses]
'''