Implementation details of OHEM (online hard example mining) in PyTorch

Let's walk through the OHEM loss implementation in detail:
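The ohem_loss function below calls a smooth_l1_loss helper (taking a sigma parameter and a reduce flag) that is not defined in the snippet itself. Here is a minimal sketch of the usual Fast R-CNN-style sigma-parameterized smooth L1 it presumably stands for; the exact form in the original codebase may differ. It sums over the 4 box coordinates so the result has shape [R], which lets it be added to the per-RoI classification loss later on:

import torch
import torch.nn.functional as F

def smooth_l1_loss(pred, target, sigma=1.0, reduce=False):
    # Assumed stand-in for the helper the snippet imports (not from the
    # original post). Quadratic for |x| < 1/sigma^2, linear beyond.
    beta = 1.0 / (sigma ** 2)
    diff = torch.abs(pred - target)
    loss = torch.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)
    # Sum over the 4 box coordinates so the result is one value per RoI ([R])
    per_roi = loss.sum(dim=1)
    return per_roi.sum() if reduce else per_roi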

def ohem_loss(
    batch_size, cls_pred, cls_target, loc_pred, loc_target, smooth_l1_sigma=1.0
):
    """
    Arguments:
        batch_size (int): number of RoIs to keep for bbox head training
        cls_pred (FloatTensor): [R, C], classification logits
        cls_target (LongTensor): [R], class labels (-1 is ignored)
        loc_pred (FloatTensor): [R, 4], predicted box offsets of positive RoIs
        loc_target (FloatTensor): [R, 4], target box offsets of positive RoIs

    Returns:
        cls_loss, loc_loss (FloatTensor)
    """
    # First compute the ordinary per-RoI classification and regression losses,
    # without any reduction
    ohem_cls_loss = F.cross_entropy(cls_pred, cls_target, reduction='none', ignore_index=-1)
    ohem_loc_loss = smooth_l1_loss(loc_pred, loc_target, sigma=smooth_l1_sigma, reduce=False)
    # Then sum them to get a single "hardness" score per RoI
    loss = ohem_cls_loss + ohem_loc_loss

    # Sort the per-RoI losses in descending order, so the hardest examples come first
    sorted_ohem_loss, idx = torch.sort(loss, descending=True)
    # Number of losses to keep: at most batch_size
    keep_num = min(sorted_ohem_loss.size(0), batch_size)
    # If there are more RoIs than we want to keep, select only the hardest
    # keep_num of them; otherwise keep them all
    if keep_num < sorted_ohem_loss.size(0):
        keep_idx_cuda = idx[:keep_num]
        # The same indices are applied to both losses, so classification and
        # regression keep exactly the same RoIs
        ohem_cls_loss = ohem_cls_loss[keep_idx_cuda]
        ohem_loc_loss = ohem_loc_loss[keep_idx_cuda]
    # Finally, average the classification and regression losses separately
    # over the kept RoIs
    cls_loss = ohem_cls_loss.sum() / keep_num
    loc_loss = ohem_loc_loss.sum() / keep_num
    return cls_loss, loc_loss
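To see the selection behavior end to end, here is a small smoke test with random inputs (all shapes and numbers are made up for illustration): 128 candidate RoIs, 21 classes, and only the 64 hardest RoIs contribute to the returned losses.

# Illustrative shapes only: R=128 candidate RoIs, C=21 classes, keep 64
R, C = 128, 21
cls_pred = torch.randn(R, C)
cls_target = torch.randint(0, C, (R,))   # LongTensor class labels
loc_pred = torch.randn(R, 4)
loc_target = torch.randn(R, 4)

cls_loss, loc_loss = ohem_loss(64, cls_pred, cls_target, loc_pred, loc_target)
print(cls_loss.item(), loc_loss.item())

Because the selection is done with plain tensor indexing, gradients flow only through the kept RoIs; the easy examples are effectively dropped from the backward pass, which is the whole point of online hard example mining.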

 
