【機器學習】focalloss原理以及pytorch實現

最近在做一個分類項目,發現很多“難樣本”比較不好處理又特別重要,想試試FocalLoss。沒找到pytorch相關實現,本來想研究pytorch的cross_entropy源碼,稍微改一下(怕手殘自己寫的loss效率比較低),但是發現有點複雜,我的任務比較簡單,改那玩意有點累。
我們知道,對於二分類:
cross_entropy(y,y^)=ylogy^1ylog1y^ cross\_entropy(y,\hat y) = -ylog^{\hat y}-(1-y)log^{1-\hat y}

cross_entropy(y,y^)={logy^y=1log1y^y=0 cross\_entropy(y,\hat y) =\begin{cases} -log^{\hat y}& \text{y=1}\\ -log^{1-\hat y}& \text{y=0} \end{cases}
y^\hat y爲模型預測概率

如果有一個正樣本,模型預測結果爲0.9,loss爲-log(0.9)約等於0.046

還有一個正樣本,模型預測結果爲0.55,loss爲-log(0.55)約等於0.260

這個預測爲0.55的樣本提供的loss是預測爲0.9的樣本的5.65

如果我把公式改成下面這樣:
γ=2FocalLoss(y,y^)={(1y^)γlogy^y=1y^γlog1y^y=0 \gamma=2\\ FocalLoss(y,\hat y) = \begin{cases} -(1-\hat y)^{\gamma}log^{\hat y}& \text{y=1}\\ -\hat y^{\gamma}log^{1-\hat y}& \text{y=0}\\ \end{cases}
這時如果有一個正樣本,模型預測結果爲0.9,loss爲-0.1*0.1*log(0.9)約等於0.00046

還有一個正樣本,模型預測結果爲0.55,loss爲-0.45*0.45*log(0.55)約等於0.0526

這個預測爲0.55的樣本提供的loss是預測爲0.9的樣本的114.35

這樣就可以讓模型更加更加關注“難樣本”

另外還可以給正負樣本的loss添加權重,讓模型更注重正/負樣本

上代碼:

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, device, gamma, alpha):
        super(FocalLoss, self).__init__()
        #self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.device = device
        self.gamma = gamma
        self.alpha = alpha
        
    def forward(self, inputs, targets): 
        if self.device == 'cpu':
            # 計算正負樣本權重
            alpha_factor = torch.ones(targets.shape) * self.alpha
            alpha_factor = torch.where(torch.eq(targets, 1), alpha_factor, 1. - alpha_factor)
            # 計算因子項
            focal_weight = torch.where(torch.eq(targets, 1), 1. - inputs, inputs)
            # 得到最終的權重
            focal_weight = alpha_factor * torch.pow(focal_weight, self.gamma)
            targets = targets.type(torch.FloatTensor) 
            # 計算標準交叉熵
            bce = F.binary_cross_entropy(inputs, targets)
            # focal loss 
            cls_loss = focal_weight * bce
        else:
            gpu_targets = targets.cuda()
            gpu_inputs = inputs.cuda()
            alpha_factor = torch.ones(gpu_targets.shape).cuda() * self.alpha
            alpha_factor = torch.where(torch.eq(gpu_targets, 1), alpha_factor, 1. - alpha_factor)
            focal_weight = torch.where(torch.eq(gpu_targets, 1), 1. - gpu_inputs, gpu_inputs)
            focal_weight = alpha_factor * torch.pow(focal_weight, self.gamma)
            targets = targets.type(torch.FloatTensor)
            bce = F.binary_cross_entropy(gpu_inputs, gpu_targets)
            focal_weight = focal_weight.cuda()
            cls_loss = focal_weight * bce

        return cls_loss.sum()

希望能幫助到大家~

pytorch學習筆記 | Focal loss的原理與pytorch實現

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章