mmdetection(1) : FCOS 代碼解析

原文我就不貼了,說一下感受吧!

從檢測方法出來,自我感覺一直不協調,現有的方法如fast系列一直比較複雜,強大的神經網絡應該是簡單高效的,one-stage從yolo出來後感覺好了很多,但是在最後的map上做roi anchor一直感覺特別冗餘,還好corner出來了,但其也存在的問題,不同的pool結構,還有總感覺這種方式怪怪的。然後FCOS出來了,完全感受到了高效和簡單,在此膜拜一下大神,感覺神經網絡就應該這樣,以最簡單的方式,取得很好的效果。

FCOS: Fully Convolutional One-Stage Object Detection

簡單介紹一下,方法很簡單,好文。
在這裏插入圖片描述

看到上圖了嗎,沒錯,就是這麼粗暴,像素級的預測t,l,r,b,當然,上面存在着一個問題,就是一個點包含兩個框,怎麼搞,當然選小的了,不不不,FPN啊:

在這裏插入圖片描述

稍微解釋一下最後的框框,x4是4層卷積,class分類,距離中心的loss,迴歸,
在這裏插入圖片描述
然後是中心度的定義,二進制交叉熵做loss。
在這裏插入圖片描述

ok,講完了, 簡單吧,粗暴吧,開始實現吧!


更新:
非常好用個點FCOS得配點代碼不是,近期研究mmdetection,感覺很好用,特來吧源碼解釋一下

def loss(self,
             cls_scores,
             bbox_preds,
             centernesses,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
 
        labels, bbox_targets = self.fcos_target(all_level_points, gt_bboxes,
                                                gt_labels)  #生成label
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_centerness = [
            centerness.permute(0, 2, 3, 1).reshape(-1)
            for centerness in centernesses
        ]
        flatten_cls_scores = torch.cat(flatten_cls_scores)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds)
        flatten_centerness = torch.cat(flatten_centerness)
        flatten_labels = torch.cat(labels)
        flatten_bbox_targets = torch.cat(bbox_targets)
     
        loss_cls = self.loss_cls(
            flatten_cls_scores, flatten_labels,
            avg_factor=num_pos + num_imgs)  # cls loss

        pos_bbox_preds = flatten_bbox_preds[pos_inds]
        pos_bbox_targets = flatten_bbox_targets[pos_inds]
        pos_centerness = flatten_centerness[pos_inds]
        pos_centerness_targets = self.centerness_target(pos_bbox_targets)


        pos_points = flatten_points[pos_inds]
        pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
        pos_decoded_target_preds = distance2bbox(pos_points,
                                                 pos_bbox_targets)
        # centerness weighted iou loss
        loss_bbox = self.loss_bbox(
            pos_decoded_bbox_preds,
            pos_decoded_target_preds,
            weight=pos_centerness_targets,
            avg_factor=pos_centerness_targets.sum())  #llox loss
        loss_centerness = self.loss_centerness(pos_centerness,
                                               pos_centerness_targets)  #center loss 交叉熵loss
      
        return dict(
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            loss_centerness=loss_centerness)
            

其中生成中心點啊方法,和原文一樣:

def centerness_target(self, pos_bbox_targets):
        # only calculate pos centerness targets, otherwise there may be nan
        left_right = pos_bbox_targets[:, [0, 2]]
        top_bottom = pos_bbox_targets[:, [1, 3]]
        centerness_targets = (
            left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * (
                top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
        return torch.sqrt(centerness_targets)

具體loss參考mmdet
focalloss+IOUloss+ CEloss

bbox_head=dict(
        type='FCOSHead',
        num_classes=81,
        in_channels=256,
        stacked_convs=4,
        feat_channels=256,
        strides=[8, 16, 32, 64, 128],
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox=dict(type='IoULoss', loss_weight=1.0),
        loss_centerness=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)))
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章