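All of the code has been uploaded to my GitHub repository: https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training
If you find it useful, please give it a star!
All of the code below has been tested with PyTorch 1.4 and confirmed to work correctly.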
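Converting regression predictions
Once the model is trained, its outputs must be decoded before it can be evaluated. A forward pass through the RetinaNet class gives us the cls heads and reg heads, but at this point the reg heads predict tx, ty, tw, th, so we have to use the corresponding anchor box coordinates to turn them into predicted box coordinates. The conversion is simply the inverse of the box-coordinates-to-regression-targets (tx, ty, tw, th) transform described in Implementing RetinaNet from Scratch (Part 4).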
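For reference, here is the same conversion written out for a single anchor. It is a transcription of the code that follows, not an independent derivation: the anchor has corners (x_1^a, y_1^a, x_2^a, y_2^a), the image is W by H pixels, the constants 0.1 and 0.2 are the scaling factors used in the code, and trunc is the truncation toward zero performed by `.int()`.

```latex
\begin{aligned}
w_a &= x_2^a - x_1^a, \qquad h_a = y_2^a - y_1^a, \qquad
c_x^a = x_1^a + \tfrac{1}{2} w_a, \qquad c_y^a = y_1^a + \tfrac{1}{2} h_a \\
c_x &= 0.1\, t_x\, w_a + c_x^a, \qquad c_y = 0.1\, t_y\, h_a + c_y^a \\
w &= w_a \exp(0.2\, t_w), \qquad h = h_a \exp(0.2\, t_h) \\
\hat{x}_1 &= \max\big(\operatorname{trunc}(c_x - \tfrac{1}{2} w),\, 0\big), \qquad
\hat{y}_1 = \max\big(\operatorname{trunc}(c_y - \tfrac{1}{2} h),\, 0\big) \\
\hat{x}_2 &= \min\big(\operatorname{trunc}(c_x + \tfrac{1}{2} w),\, W - 1\big), \qquad
\hat{y}_2 = \min\big(\operatorname{trunc}(c_y + \tfrac{1}{2} h),\, H - 1\big)
\end{aligned}
```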
The code that converts regression predictions into box predictions is as follows:
def snap_tx_ty_tw_th_reg_heads_to_x1_y1_x2_y2_bboxes(
        self, reg_heads, anchors):
    """
    snap reg heads to pred bboxes
    reg_heads:[anchor_nums,4],4:[tx,ty,tw,th]
    anchors:[anchor_nums,4],4:[x_min,y_min,x_max,y_max]
    """
    # anchor width/height and center
    anchors_wh = anchors[:, 2:] - anchors[:, :2]
    anchors_ctr = anchors[:, :2] + 0.5 * anchors_wh
    device = anchors.device
    # rescale tx,ty by 0.1 and tw,th by 0.2 (inverse of the target encoding)
    factor = torch.tensor([[0.1, 0.1, 0.2, 0.2]]).to(device)
    reg_heads = reg_heads * factor
    # w,h = exp(tw,th) * anchor w,h; center = (tx,ty) * anchor w,h + anchor center
    pred_bboxes_wh = torch.exp(reg_heads[:, 2:]) * anchors_wh
    pred_bboxes_ctr = reg_heads[:, :2] * anchors_wh + anchors_ctr
    # center/size -> corner coordinates
    pred_bboxes_x_min_y_min = pred_bboxes_ctr - 0.5 * pred_bboxes_wh
    pred_bboxes_x_max_y_max = pred_bboxes_ctr + 0.5 * pred_bboxes_wh
    pred_bboxes = torch.cat(
        [pred_bboxes_x_min_y_min, pred_bboxes_x_max_y_max], dim=1)
    # truncate to integer pixel coordinates, then clip to the image
    pred_bboxes = pred_bboxes.int()
    pred_bboxes[:, 0] = torch.clamp(pred_bboxes[:, 0], min=0)
    pred_bboxes[:, 1] = torch.clamp(pred_bboxes[:, 1], min=0)
    pred_bboxes[:, 2] = torch.clamp(pred_bboxes[:, 2],
                                    max=self.image_w - 1)
    pred_bboxes[:, 3] = torch.clamp(pred_bboxes[:, 3],
                                    max=self.image_h - 1)
    # pred bboxes shape:[anchor_nums,4]
    return pred_bboxes
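As a quick sanity check of the decoding rule, here is a plain-Python sketch for a single anchor. It assumes, as stated above, that the target encoding is the exact inverse of this decoding, i.e. that the encoder divides by the same 0.1/0.2 factors. `encode_box`, `decode_box`, and `decode_and_clip` are hypothetical helper names written for this illustration; they are not functions from the repository.

```python
import math

# Scaling factors taken from the decoder above: 0.1 for tx, ty and 0.2 for tw, th.
FACTORS = (0.1, 0.1, 0.2, 0.2)


def encode_box(box, anchor):
    """Hypothetical inverse of decode_box: [x_min, y_min, x_max, y_max] -> (tx, ty, tw, th).

    Assumes the encoder divides by the same FACTORS the decoder multiplies by.
    """
    aw, ah = anchor[2] - anchor[0], anchor[3] - anchor[1]
    acx, acy = anchor[0] + 0.5 * aw, anchor[1] + 0.5 * ah
    bw, bh = box[2] - box[0], box[3] - box[1]
    bcx, bcy = box[0] + 0.5 * bw, box[1] + 0.5 * bh
    return (
        (bcx - acx) / aw / FACTORS[0],
        (bcy - acy) / ah / FACTORS[1],
        math.log(bw / aw) / FACTORS[2],
        math.log(bh / ah) / FACTORS[3],
    )


def decode_box(reg, anchor):
    """(tx, ty, tw, th) plus anchor [x_min, y_min, x_max, y_max] -> float box."""
    tx, ty, tw, th = (r * f for r, f in zip(reg, FACTORS))
    aw, ah = anchor[2] - anchor[0], anchor[3] - anchor[1]
    acx, acy = anchor[0] + 0.5 * aw, anchor[1] + 0.5 * ah
    w, h = math.exp(tw) * aw, math.exp(th) * ah
    cx, cy = tx * aw + acx, ty * ah + acy
    return [cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h]


def decode_and_clip(reg, anchor, image_w, image_h):
    """Mirror of the tensor code: truncate toward zero, then clip to the image."""
    x1, y1, x2, y2 = (int(v) for v in decode_box(reg, anchor))
    return [max(x1, 0), max(y1, 0), min(x2, image_w - 1), min(y2, image_h - 1)]
```

Encoding a box and decoding it again returns the same coordinates up to floating-point error, and an all-zero regression output simply reproduces the anchor. Python's `int()` truncates toward zero, which matches what `pred_bboxes.int()` does to float tensors.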
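NMS post-processing
The standard NMS procedure works as follows. First, sort all candidates by classification score in descending order and record which classes occur among them. Then iterate over these detected classes. For each class, extract all candidates of that class (since everything was sorted at the start, the extracted subset is still in descending score order). Move the first candidate into the set of kept detections, compute the IoU between every remaining candidate and that box, and discard every candidate whose IoU is at or above the threshold (the code below keeps only candidates with IoU < threshold). For RetinaNet the threshold is 0.5. Repeat the same steps on the candidates that survive, each time moving the first one into the kept set, until no candidates remain; NMS for that class is then done. Once every class has been processed, NMS is complete.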
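The per-class procedure just described can be sketched in plain Python as follows. This is an illustrative re-implementation on Python lists rather than the tensor code below; `iou` and `per_class_nms` are hypothetical names. Like the code below, it drops a candidate when its IoU with the kept box is greater than or equal to the threshold, and it computes box areas without the "+1" pixel convention.

```python
def iou(a, b):
    """IoU of two [x_min, y_min, x_max, y_max] boxes (widths are x_max - x_min)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(ix2 - ix1, 0) * max(iy2 - iy1, 0)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    # Clamp the union the same way the tensor code does (min=1e-4).
    return inter / max(area_a + area_b - inter, 1e-4)


def per_class_nms(scores, classes, boxes, iou_threshold=0.5):
    """Greedy NMS run separately for each class; returns the kept indices."""
    # Sort all candidates once, by descending score.
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    keep = []
    for cls in sorted(set(classes)):
        # Candidates of this class, still in descending score order.
        remaining = [i for i in order if classes[i] == cls]
        while remaining:
            top = remaining.pop(0)
            keep.append(top)
            # Keep only boxes that overlap the kept box less than the threshold.
            remaining = [i for i in remaining
                         if iou(boxes[top], boxes[i]) < iou_threshold]
    return keep
```

With two heavily overlapping class-0 boxes, one distant class-0 box, and one class-1 box, the lower-scoring overlapping box is suppressed and the other three are kept. The kept indices come out grouped by class rather than globally sorted by score.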
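I noticed that many other object detection implementations do not run NMS per class during post-processing; instead, candidates of all classes go through NMS together. I tried this approach as well and found that it was consistently about 0.2 to 0.5 mAP lower than standard per-class NMS, so the implementation below sticks with the standard method.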
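A toy illustration of how the two variants differ, using hypothetical boxes and labels: when two strongly overlapping boxes carry different class labels, class-agnostic NMS suppresses the lower-scoring one, while per-class NMS keeps both. This shows one plausible mechanism behind the mAP gap reported above; it does not reproduce that measurement. All function names here are hypothetical.

```python
def iou(a, b):
    """IoU of two [x_min, y_min, x_max, y_max] boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(ix2 - ix1, 0) * max(iy2 - iy1, 0)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / max(area_a + area_b - inter, 1e-4)


def greedy_nms(indices, scores, boxes, iou_threshold):
    """Plain greedy NMS over the given candidate indices."""
    order = sorted(indices, key=lambda i: scores[i], reverse=True)
    keep = []
    while order:
        top = order.pop(0)
        keep.append(top)
        order = [i for i in order if iou(boxes[top], boxes[i]) < iou_threshold]
    return keep


def class_agnostic_nms(scores, classes, boxes, iou_threshold=0.5):
    # `classes` is ignored: candidates of all classes compete with each other.
    return greedy_nms(range(len(scores)), scores, boxes, iou_threshold)


def per_class_nms(scores, classes, boxes, iou_threshold=0.5):
    # Candidates only suppress other candidates of the same class.
    keep = []
    for cls in sorted(set(classes)):
        indices = [i for i in range(len(scores)) if classes[i] == cls]
        keep.extend(greedy_nms(indices, scores, boxes, iou_threshold))
    return keep
```

For two boxes with IoU of about 0.68 and different labels, the class-agnostic variant returns only the higher-scoring index, whereas the per-class variant returns both.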
The NMS post-processing code is as follows:
def nms(self, one_image_scores, one_image_classes, one_image_pred_bboxes):
    """
    one_image_scores:[anchor_nums],classification predict scores
    one_image_classes:[anchor_nums],class indexes for predict scores
    one_image_pred_bboxes:[anchor_nums,4],4:x_min,y_min,x_max,y_max
    """
    # Sort boxes by score, descending
    sorted_one_image_scores, sorted_one_image_scores_indexes = torch.sort(
        one_image_scores, descending=True)
    sorted_one_image_classes = one_image_classes[
        sorted_one_image_scores_indexes]
    sorted_one_image_pred_bboxes = one_image_pred_bboxes[
        sorted_one_image_scores_indexes]
    # box areas, computed once for all candidates
    sorted_pred_bboxes_w_h = (sorted_one_image_pred_bboxes[:, 2:] -
                              sorted_one_image_pred_bboxes[:, :2])
    sorted_pred_bboxes_areas = (sorted_pred_bboxes_w_h[:, 0] *
                                sorted_pred_bboxes_w_h[:, 1])
    detected_classes = torch.unique(sorted_one_image_classes, sorted=True)
    keep_scores, keep_classes, keep_pred_bboxes = [], [], []
    for detected_class in detected_classes:
        # candidates of this class, still sorted by score
        class_mask = sorted_one_image_classes == detected_class
        single_class_scores = sorted_one_image_scores[class_mask]
        single_class_pred_bboxes = sorted_one_image_pred_bboxes[class_mask]
        single_class_pred_bboxes_areas = sorted_pred_bboxes_areas[class_mask]
        single_class = sorted_one_image_classes[class_mask]
        single_keep_scores, single_keep_classes, single_keep_pred_bboxes = [], [], []
        while single_class_scores.numel() > 0:
            # keep the highest-scoring remaining box
            top1_score, top1_class, top1_pred_bbox = single_class_scores[
                0:1], single_class[0:1], single_class_pred_bboxes[0:1]
            single_keep_scores.append(top1_score)
            single_keep_classes.append(top1_class)
            single_keep_pred_bboxes.append(top1_pred_bbox)
            top1_areas = single_class_pred_bboxes_areas[0]
            if single_class_scores.numel() == 1:
                break
            # remove it from the candidate list
            single_class_scores = single_class_scores[1:]
            single_class = single_class[1:]
            single_class_pred_bboxes = single_class_pred_bboxes[1:]
            single_class_pred_bboxes_areas = single_class_pred_bboxes_areas[1:]
            # intersection of the top1 box with every remaining box
            overlap_area_top_left = torch.max(
                single_class_pred_bboxes[:, :2], top1_pred_bbox[:, :2])
            overlap_area_bot_right = torch.min(
                single_class_pred_bboxes[:, 2:], top1_pred_bbox[:, 2:])
            overlap_area_sizes = torch.clamp(
                overlap_area_bot_right - overlap_area_top_left, min=0)
            overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:, 1]
            # compute union_area
            union_area = top1_areas + single_class_pred_bboxes_areas - overlap_area
            union_area = torch.clamp(union_area, min=1e-4)
            # compute ious for top1 pred_bbox and the other pred_bboxes
            ious = overlap_area / union_area
            # discard boxes whose IoU with top1 reaches the threshold
            keep_mask = ious < self.nms_threshold
            single_class_scores = single_class_scores[keep_mask]
            single_class = single_class[keep_mask]
            single_class_pred_bboxes = single_class_pred_bboxes[keep_mask]
            single_class_pred_bboxes_areas = single_class_pred_bboxes_areas[
                keep_mask]
        single_keep_scores = torch.cat(single_keep_scores, dim=0)
        single_keep_classes = torch.cat(single_keep_classes, dim=0)
        single_keep_pred_bboxes = torch.cat(single_keep_pred_bboxes, dim=0)
        keep_scores.append(single_keep_scores)
        keep_classes.append(single_keep_classes)
        keep_pred_bboxes.append(single_keep_pred_bboxes)
    # results are grouped by class, not globally sorted by score
    keep_scores = torch.cat(keep_scores, dim=0)
    keep_classes = torch.cat(keep_classes, dim=0)
    keep_pred_bboxes = torch.cat(keep_pred_bboxes, dim=0)
    return keep_scores, keep_classes, keep_pred_bboxes
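Decoding
With the two pieces above in place, we can now implement decoding. The full decoding pipeline is as follows. First, convert the tx, ty, tw, th predictions of the reg heads into box coordinate predictions (this requires the anchor coordinates). Next, take each anchor's highest class score and use a classification score threshold to filter out candidates whose scores are too low; for RetinaNet this threshold is 0.05. Then run NMS post-processing on the remaining candidates to obtain the kept detections. Finally, we set max_detection_num, the maximum number of detections kept in the final output. For the COCO dataset this value is 100, because no single COCO image is annotated with more than 100 objects.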
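The per-image pipeline can be sketched in plain Python on lists as a reference for what the tensor code computes. This is an illustration under simplifying assumptions: the boxes are taken as already decoded, and `decode_one_image`, `per_class_nms`, and `iou` are hypothetical helper names, not part of the repository.

```python
def iou(a, b):
    """IoU of two [x_min, y_min, x_max, y_max] boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(ix2 - ix1, 0) * max(iy2 - iy1, 0)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / max(area_a + area_b - inter, 1e-4)


def per_class_nms(candidates, scores, classes, boxes, iou_threshold):
    """Greedy per-class NMS over candidate indices; result is grouped by class."""
    order = sorted(candidates, key=lambda i: scores[i], reverse=True)
    keep = []
    for cls in sorted({classes[i] for i in order}):
        remaining = [i for i in order if classes[i] == cls]
        while remaining:
            top = remaining.pop(0)
            keep.append(top)
            remaining = [i for i in remaining
                         if iou(boxes[top], boxes[i]) < iou_threshold]
    return keep


def decode_one_image(cls_scores, boxes, min_score_threshold=0.05,
                     nms_threshold=0.5, max_detection_num=100):
    """cls_scores: per-anchor lists of class scores; boxes: decoded boxes."""
    # 1) best class and its score for every anchor
    best = [max(range(len(s)), key=lambda c: s[c]) for s in cls_scores]
    scores = [s[c] for s, c in zip(cls_scores, best)]
    # 2) drop candidates whose score is not above the threshold
    candidates = [i for i in range(len(scores)) if scores[i] > min_score_threshold]
    # 3) per-class NMS on the survivors
    keep = per_class_nms(candidates, scores, best, boxes, nms_threshold)
    # 4) NMS output is grouped by class, so sort by score, then truncate
    keep = sorted(keep, key=lambda i: scores[i], reverse=True)[:max_detection_num]
    # 5) pad every output to a fixed length with -1
    pad = max_detection_num - len(keep)
    out_scores = [scores[i] for i in keep] + [-1.0] * pad
    out_classes = [best[i] for i in keep] + [-1] * pad
    out_boxes = [list(boxes[i]) for i in keep] + [[-1.0] * 4 for _ in range(pad)]
    return out_scores, out_classes, out_boxes
```

The result always has exactly max_detection_num entries, with unused slots filled with -1, which matches the padding used by RetinaDecoder. The re-sort after NMS matters because per-class NMS returns its kept boxes grouped by class.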
The decoding code is as follows:
import torch
import torch.nn as nn


class RetinaDecoder(nn.Module):
    def __init__(self,
                 image_w,
                 image_h,
                 min_score_threshold=0.05,
                 nms_threshold=0.5,
                 max_detection_num=100):
        super(RetinaDecoder, self).__init__()
        self.image_w = image_w
        self.image_h = image_h
        self.min_score_threshold = min_score_threshold
        self.nms_threshold = nms_threshold
        self.max_detection_num = max_detection_num

    def forward(self, cls_heads, reg_heads, batch_anchors):
        device = cls_heads[0].device
        # concatenate the per-level outputs along the anchor dimension
        cls_heads = torch.cat(cls_heads, dim=1)
        reg_heads = torch.cat(reg_heads, dim=1)
        batch_anchors = torch.cat(batch_anchors, dim=1)
        batch_scores, batch_classes, batch_pred_bboxes = [], [], []
        for per_image_cls_heads, per_image_reg_heads, per_image_anchors in zip(
                cls_heads, reg_heads, batch_anchors):
            # tx,ty,tw,th -> x_min,y_min,x_max,y_max
            pred_bboxes = self.snap_tx_ty_tw_th_reg_heads_to_x1_y1_x2_y2_bboxes(
                per_image_reg_heads, per_image_anchors)
            # highest class score (and its class index) for every anchor
            scores, score_classes = torch.max(per_image_cls_heads, dim=1)
            # filter out candidates at or below min_score_threshold
            score_mask = scores > self.min_score_threshold
            score_classes = score_classes[score_mask].float()
            pred_bboxes = pred_bboxes[score_mask].float()
            scores = scores[score_mask].float()
            # fixed-size outputs; unused slots stay at -1
            single_image_scores = (-1) * torch.ones(
                (self.max_detection_num, ), device=device)
            single_image_classes = (-1) * torch.ones(
                (self.max_detection_num, ), device=device)
            single_image_pred_bboxes = (-1) * torch.ones(
                (self.max_detection_num, 4), device=device)
            if scores.shape[0] != 0:
                scores, score_classes, pred_bboxes = self.nms(
                    scores, score_classes, pred_bboxes)
                # nms output is grouped by class, so sort by score again
                sorted_keep_scores, sorted_keep_scores_indexes = torch.sort(
                    scores, descending=True)
                sorted_keep_classes = score_classes[sorted_keep_scores_indexes]
                sorted_keep_pred_bboxes = pred_bboxes[
                    sorted_keep_scores_indexes]
                # keep at most max_detection_num detections
                final_detection_num = min(self.max_detection_num,
                                          sorted_keep_scores.shape[0])
                single_image_scores[0:final_detection_num] = sorted_keep_scores[
                    0:final_detection_num]
                single_image_classes[
                    0:final_detection_num] = sorted_keep_classes[
                        0:final_detection_num]
                single_image_pred_bboxes[
                    0:final_detection_num, :] = sorted_keep_pred_bboxes[
                        0:final_detection_num, :]
            single_image_scores = single_image_scores.unsqueeze(0)
            single_image_classes = single_image_classes.unsqueeze(0)
            single_image_pred_bboxes = single_image_pred_bboxes.unsqueeze(0)
            batch_scores.append(single_image_scores)
            batch_classes.append(single_image_classes)
            batch_pred_bboxes.append(single_image_pred_bboxes)
        batch_scores = torch.cat(batch_scores, dim=0)
        batch_classes = torch.cat(batch_classes, dim=0)
        batch_pred_bboxes = torch.cat(batch_pred_bboxes, dim=0)
        # batch_scores shape:[batch_size,max_detection_num]
        # batch_classes shape:[batch_size,max_detection_num]
        # batch_pred_bboxes shape[batch_size,max_detection_num,4]
        return batch_scores, batch_classes, batch_pred_bboxes

    def nms(self, one_image_scores, one_image_classes, one_image_pred_bboxes):
        """
        one_image_scores:[anchor_nums],classification predict scores
        one_image_classes:[anchor_nums],class indexes for predict scores
        one_image_pred_bboxes:[anchor_nums,4],4:x_min,y_min,x_max,y_max
        """
        # Sort boxes by score, descending
        sorted_one_image_scores, sorted_one_image_scores_indexes = torch.sort(
            one_image_scores, descending=True)
        sorted_one_image_classes = one_image_classes[
            sorted_one_image_scores_indexes]
        sorted_one_image_pred_bboxes = one_image_pred_bboxes[
            sorted_one_image_scores_indexes]
        sorted_pred_bboxes_w_h = (sorted_one_image_pred_bboxes[:, 2:] -
                                  sorted_one_image_pred_bboxes[:, :2])
        sorted_pred_bboxes_areas = (sorted_pred_bboxes_w_h[:, 0] *
                                    sorted_pred_bboxes_w_h[:, 1])
        detected_classes = torch.unique(sorted_one_image_classes, sorted=True)
        keep_scores, keep_classes, keep_pred_bboxes = [], [], []
        for detected_class in detected_classes:
            # candidates of this class, still sorted by score
            class_mask = sorted_one_image_classes == detected_class
            single_class_scores = sorted_one_image_scores[class_mask]
            single_class_pred_bboxes = sorted_one_image_pred_bboxes[class_mask]
            single_class_pred_bboxes_areas = sorted_pred_bboxes_areas[class_mask]
            single_class = sorted_one_image_classes[class_mask]
            single_keep_scores, single_keep_classes, single_keep_pred_bboxes = [], [], []
            while single_class_scores.numel() > 0:
                # keep the highest-scoring remaining box
                top1_score, top1_class, top1_pred_bbox = single_class_scores[
                    0:1], single_class[0:1], single_class_pred_bboxes[0:1]
                single_keep_scores.append(top1_score)
                single_keep_classes.append(top1_class)
                single_keep_pred_bboxes.append(top1_pred_bbox)
                top1_areas = single_class_pred_bboxes_areas[0]
                if single_class_scores.numel() == 1:
                    break
                single_class_scores = single_class_scores[1:]
                single_class = single_class[1:]
                single_class_pred_bboxes = single_class_pred_bboxes[1:]
                single_class_pred_bboxes_areas = single_class_pred_bboxes_areas[
                    1:]
                overlap_area_top_left = torch.max(
                    single_class_pred_bboxes[:, :2], top1_pred_bbox[:, :2])
                overlap_area_bot_right = torch.min(
                    single_class_pred_bboxes[:, 2:], top1_pred_bbox[:, 2:])
                overlap_area_sizes = torch.clamp(
                    overlap_area_bot_right - overlap_area_top_left, min=0)
                overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:, 1]
                # compute union_area
                union_area = top1_areas + single_class_pred_bboxes_areas - overlap_area
                union_area = torch.clamp(union_area, min=1e-4)
                # compute ious for top1 pred_bbox and the other pred_bboxes
                ious = overlap_area / union_area
                # discard boxes whose IoU with top1 reaches the threshold
                keep_mask = ious < self.nms_threshold
                single_class_scores = single_class_scores[keep_mask]
                single_class = single_class[keep_mask]
                single_class_pred_bboxes = single_class_pred_bboxes[keep_mask]
                single_class_pred_bboxes_areas = single_class_pred_bboxes_areas[
                    keep_mask]
            single_keep_scores = torch.cat(single_keep_scores, dim=0)
            single_keep_classes = torch.cat(single_keep_classes, dim=0)
            single_keep_pred_bboxes = torch.cat(single_keep_pred_bboxes, dim=0)
            keep_scores.append(single_keep_scores)
            keep_classes.append(single_keep_classes)
            keep_pred_bboxes.append(single_keep_pred_bboxes)
        keep_scores = torch.cat(keep_scores, dim=0)
        keep_classes = torch.cat(keep_classes, dim=0)
        keep_pred_bboxes = torch.cat(keep_pred_bboxes, dim=0)
        return keep_scores, keep_classes, keep_pred_bboxes

    def snap_tx_ty_tw_th_reg_heads_to_x1_y1_x2_y2_bboxes(
            self, reg_heads, anchors):
        """
        snap reg heads to pred bboxes
        reg_heads:[anchor_nums,4],4:[tx,ty,tw,th]
        anchors:[anchor_nums,4],4:[x_min,y_min,x_max,y_max]
        """
        anchors_wh = anchors[:, 2:] - anchors[:, :2]
        anchors_ctr = anchors[:, :2] + 0.5 * anchors_wh
        device = anchors.device
        factor = torch.tensor([[0.1, 0.1, 0.2, 0.2]]).to(device)
        reg_heads = reg_heads * factor
        pred_bboxes_wh = torch.exp(reg_heads[:, 2:]) * anchors_wh
        pred_bboxes_ctr = reg_heads[:, :2] * anchors_wh + anchors_ctr
        pred_bboxes_x_min_y_min = pred_bboxes_ctr - 0.5 * pred_bboxes_wh
        pred_bboxes_x_max_y_max = pred_bboxes_ctr + 0.5 * pred_bboxes_wh
        pred_bboxes = torch.cat(
            [pred_bboxes_x_min_y_min, pred_bboxes_x_max_y_max], dim=1)
        pred_bboxes = pred_bboxes.int()
        pred_bboxes[:, 0] = torch.clamp(pred_bboxes[:, 0], min=0)
        pred_bboxes[:, 1] = torch.clamp(pred_bboxes[:, 1], min=0)
        pred_bboxes[:, 2] = torch.clamp(pred_bboxes[:, 2],
                                        max=self.image_w - 1)
        pred_bboxes[:, 3] = torch.clamp(pred_bboxes[:, 3],
                                        max=self.image_h - 1)
        # pred bboxes shape:[anchor_nums,4]
        return pred_bboxes
This completes the implementation of the decoding step.