Focal loss 基於pytorch實現.

Focal_loss

作者是一名深度學習工程師,主要研究計算機視覺與三維點雲處理,作者Github:Github.歡迎多多交流.

同時作者也是pytorch中文網的維護者之一,網站會不定時更新深度學習相關文章,歡迎大家多多支持.


retinanet是ICCV2017的Best Student Paper Award(最佳學生論文),何凱明是其作者之一.文章中最爲精華的部分就是損失函數 Focal loss的提出.

論文中提出類別失衡是造成two-stage與one-stage模型精確度差異的原因.並提出了Focal loss損失函數,通過調整類間平衡因子與難易度平衡因子.最終使one-stage模型達到了two-stage的精確度.


本項目基於pytorch實現focal loss,力圖給你原生pytorch損失函數的使用體驗.

一. 項目簡介

實現過程簡易明瞭,全中文備註.

阿爾法α 參數用於調整類別權重

伽馬γ 參數用於調整不同檢測難易樣本的權重,讓模型快速關注於困難樣本

完整項目地址:Github,歡迎star, fork. github還有其他視覺相關項目

github連接較慢的,可以去Gitee(國內的代碼託管網站),也有完整項目.

項目配有 Jupyter-Notebook 作爲focal loss使用例子.

二. 損失函數公式

focal loss 損失函數基於交叉熵損失函數,在交叉熵的基礎上,引入了α與γ兩個不同的調整因子.

2.1 交叉熵損失

2.2 帶平衡因子的交叉熵

2.3 Focal loss損失

加入 (1-pt)γ 平衡難易樣本的權重,通過γ縮放因子調整,retainnet默認γ=2

2.4 帶平衡因子的Focal損失

論文中最終爲帶平衡因子的focal loss, 本項目實現的也是這個版本


三. Focal loss實現

# -*- coding: utf-8 -*-
# @Author  : LG
from torch import nn
import torch
from torch.nn import functional as F

class focal_loss(nn.Module):    
    def __init__(self, alpha=0.25, gamma=2, num_classes = 3, size_average=True):
        """
        focal_loss損失函數, -α(1-yi)**γ *ce_loss(xi,yi)      
        步驟詳細的實現了 focal_loss損失函數.
        :param alpha:   阿爾法α,類別權重.      當α是列表時,爲各類別權重,當α爲常數時,類別權重爲[α, 1-α, 1-α, ....],常用於 目標檢測算法中抑制背景類 , retainnet中設置爲0.25
        :param gamma:   伽馬γ,難易樣本調節參數. retainnet中設置爲2
        :param num_classes:     類別數量
        :param size_average:    損失計算方式,默認取均值
        """
        
        super(focal_loss,self).__init__()
        self.size_average = size_average
        if isinstance(alpha,list):
            assert len(alpha)==num_classes   # α可以以list方式輸入,size:[num_classes] 用於對不同類別精細地賦予權重
            print("Focal_loss alpha = {}, 將對每一類權重進行精細化賦值".format(alpha))
            self.alpha = torch.Tensor(alpha)
        else:
            assert alpha<1   #如果α爲一個常數,則降低第一類的影響,在目標檢測中爲第一類
            print(" --- Focal_loss alpha = {} ,將對背景類進行衰減,請在目標檢測任務中使用 --- ".format(alpha))
            self.alpha = torch.zeros(num_classes)
            self.alpha[0] += alpha
            self.alpha[1:] += (1-alpha) # α 最終爲 [ α, 1-α, 1-α, 1-α, 1-α, ...] size:[num_classes]
        self.gamma = gamma
        
    def forward(self, preds, labels):
        """
        focal_loss損失計算        
        :param preds:   預測類別. size:[B,N,C] or [B,C]    分別對應與檢測與分類任務, B 批次, N檢測框數, C類別數        
        :param labels:  實際類別. size:[B,N] or [B]        
        :return:
        """        
        # assert preds.dim()==2 and labels.dim()==1        
        preds = preds.view(-1,preds.size(-1))        
        self.alpha = self.alpha.to(preds.device)        
        preds_softmax = F.softmax(preds, dim=1) # 這裏並沒有直接使用log_softmax, 因爲後面會用到softmax的結果(當然你也可以使用log_softmax,然後進行exp操作)        
        preds_logsoft = torch.log(preds_softmax)
        preds_softmax = preds_softmax.gather(1,labels.view(-1,1))   # 這部分實現nll_loss ( crossempty = log_softmax + nll )        
        preds_logsoft = preds_logsoft.gather(1,labels.view(-1,1))        
        self.alpha = self.alpha.gather(0,labels.view(-1))        
        loss = -torch.mul(torch.pow((1-preds_softmax), self.gamma), preds_logsoft)  # torch.pow((1-preds_softmax), self.gamma) 爲focal loss中 (1-pt)**γ
        loss = torch.mul(self.alpha, loss.t())        
        if self.size_average:        
            loss = loss.mean()        
        else:            
            loss = loss.sum()        
        return loss

詳細的使用例子請到Github查看jupyter-notebook.

說明

完整項目地址:Github,歡迎star, fork.

僅限用於交流學習,如需引用,請聯繫作者.

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章