Training RFBNet on the Kaggle RSNA dataset to detect pneumonia in chest X-rays

RFBNet is a one-stage detector that delivers solid accuracy without giving up speed, so I used it to train on the RSNA dataset from Kaggle. This post mainly covers the changes to the RFBNet source code needed to support training on RSNA; if you are looking for an analysis of the RSNA data itself, check out the kernels on Kaggle.

Dataset introduction

One way RSNA differs from the usual detection datasets (COCO, VOC, BDD100K, Cityscapes, etc.) is that an image may carry no annotation at all, i.e. no foreground. I had a hunch the original code would not handle this case, and sure enough the training crashed once the dataloader was written, so the source code had to be modified.
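To make this concrete, here is a small sketch (my own illustration, not code from the RFBNet repo) that groups the boxes by patientId; the column names patientId, x, y, width, height, Target are assumed from Kaggle's stage_2_train_labels.csv, where rows with Target == 0 carry no box:

    import pandas as pd

    # Column names assumed from the Kaggle stage_2_train_labels.csv.
    # Negative rows (Target == 0) have NaN coordinates and contribute
    # an empty box list for that patientId.
    labels = pd.read_csv('stage_2_train_labels.csv')
    boxes = {}
    for row in labels.itertuples():
        boxes.setdefault(row.patientId, [])
        if row.Target == 1:
            boxes[row.patientId].append([row.x, row.y, row.width, row.height])

    num_empty = sum(1 for b in boxes.values() if len(b) == 0)
    print('%d / %d images have no annotation' % (num_empty, len(boxes)))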

Source code changes

1. Write a dataloader for RSNA

Everyone has their own style here; the main points are (a minimal sketch follows the list):

1. Read the DICOM files with SimpleITK.

2. When the current image has no annotation, have load annotation return np.zeros((1, 5)).
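
Below is a minimal sketch of such a dataloader. It is only an illustration: the class name RSNADetection, the annotations dict (patientId -> list of [x, y, w, h] boxes, e.g. built as in the snippet above), and the RFBNet-style preproc hook are assumptions, not the exact code I used, and any coordinate normalization the augmentation expects is left out.

    import os
    import numpy as np
    import SimpleITK as sitk
    import torch.utils.data as data


    class RSNADetection(data.Dataset):
        """Minimal RSNA detection dataset: one DICOM per patientId."""

        def __init__(self, root, annotations, preproc=None):
            # annotations: dict patientId -> list of [x, y, w, h] boxes
            self.root = root
            self.annotations = annotations
            self.ids = list(annotations.keys())
            self.preproc = preproc

        def __len__(self):
            return len(self.ids)

        def load_image(self, index):
            # Read the dicom with SimpleITK; GetArrayFromImage gives (1, H, W)
            path = os.path.join(self.root, self.ids[index] + '.dcm')
            img = sitk.GetArrayFromImage(sitk.ReadImage(path))[0]
            # Stack the single channel to 3 channels for the RGB backbone
            return np.stack([img] * 3, axis=-1).astype(np.float32)

        def load_annotation(self, index):
            boxes = self.annotations[self.ids[index]]
            if len(boxes) == 0:
                # Image without any opacity: all-zero (1, 5) target so the
                # rest of the pipeline still sees a well-formed array
                return np.zeros((1, 5), dtype=np.float32)
            target = np.zeros((len(boxes), 5), dtype=np.float32)
            for i, (x, y, w, h) in enumerate(boxes):
                target[i, :4] = [x, y, x + w, y + h]  # xmin, ymin, xmax, ymax
                target[i, 4] = 1                      # single foreground class
            return target

        def __getitem__(self, index):
            img = self.load_image(index)
            target = self.load_annotation(index)
            if self.preproc is not None:  # augmentation hook, RFBNet-style
                img, target = self.preproc(img, target)
            return img, target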

2. Modify multibox_loss.py

The original code performs hard negative mining: it picks a number of background priors proportional to the number of foreground priors. When there is no foreground, no background gets picked either, and computing the cross entropy for positive/negative classification then throws an error.

I added a piece of logic: when there is no foreground, pick 4 background priors for the loss instead. This corresponds to the constant_min / neg_min / num_neg lines in the code below.

    def forward(self, predictions, priors, targets):
        """Multibox Loss
        Args:
            predictions (tuple): A tuple containing loc preds, conf preds,
            and prior boxes from SSD net.
                conf shape: torch.size(batch_size,num_priors,num_classes)
                loc shape: torch.size(batch_size,num_priors,4)
                priors shape: torch.size(num_priors,4)

            targets (tensor): Ground truth boxes and labels for a batch,
                shape: [batch_size,num_objs,5] (last idx is the label).
        """

        loc_data, conf_data = predictions
        priors = priors
        num = loc_data.size(0)
        num_priors = (priors.size(0))
        num_classes = self.num_classes

        # match priors (default boxes) and ground truth boxes
        loc_t = torch.Tensor(num, num_priors, 4)
        conf_t = torch.LongTensor(num, num_priors)
        for idx in range(num):
            truths = targets[idx][:, :-1].data
            labels = targets[idx][:, -1].data
            defaults = priors.data
            match(self.threshold, truths, defaults, self.variance, labels, loc_t, conf_t, idx)
        if GPU:
            loc_t = loc_t.cuda()
            conf_t = conf_t.cuda()
        # wrap targets
        loc_t = Variable(loc_t, requires_grad=False)
        conf_t = Variable(conf_t, requires_grad=False)

        pos = conf_t > 0

        # Localization Loss (Smooth L1)
        # Shape: [batch,num_priors,4]
        pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
        loc_p = loc_data[pos_idx].view(-1, 4)
        loc_t = loc_t[pos_idx].view(-1, 4)
        loss_l = F.smooth_l1_loss(loc_p, loc_t, size_average=False)

        # Compute max conf across batch for hard negative mining
        batch_conf = conf_data.view(-1, self.num_classes)
        loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))

        # Hard Negative Mining
        loss_c[pos.view(-1, 1)] = 0  # filter out pos boxes for now
        loss_c = loss_c.view(num, -1)
        _, loss_idx = loss_c.sort(1, descending=True)
        _, idx_rank = loss_idx.sort(1)
        num_pos = pos.long().sum(1, keepdim=True)

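        # Added logic (see the note above the code): when an image has no
        # foreground, num_pos is 0 and self.negpos_ratio * num_pos alone would
        # select no negatives, leaving the cross entropy below with an empty
        # set, so keep a floor of 4 negatives per image instead.
        # Note: constant_min.cuda() assumes GPU training, like the GPU flag above.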
        constant_min = torch.ones(num_pos.shape, dtype=torch.int64) * 4
        neg_min = torch.max(self.negpos_ratio * num_pos, constant_min.cuda())
        num_neg = torch.clamp(neg_min, max=pos.size(1) - 1)
        neg = idx_rank < num_neg.expand_as(idx_rank)

        # Confidence Loss Including Positive and Negative Examples
        pos_idx = pos.unsqueeze(2).expand_as(conf_data)
        neg_idx = neg.unsqueeze(2).expand_as(conf_data)
        conf_p = conf_data[(pos_idx + neg_idx).gt(0)].view(-1, self.num_classes)
        targets_weighted = conf_t[(pos + neg).gt(0)]
        loss_c = F.cross_entropy(conf_p, targets_weighted, size_average=False)

        # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N

        N = max(num_pos.data.sum().float(), 1)
        loss_l /= N
        loss_c /= N
        return loss_l, loss_c
