pytorch學習筆記 | Focal loss的原理與pytorch實現

Focal 原理簡述

Focal loss是一個針對單階段物體檢測任務中正負樣本不均衡而提出來的損失函數,論文地址來自arxiv

數學定義

先放focal loss(FL)和cross entropy(CE)兩個函數的數學定義。
其中 p 爲概率,而 y 爲 0 或 1 的標籤。

可以看到focal loss的設計很簡單明瞭,就是在標準交叉熵損失函數的引入一個因子(1pt)λ(1 - p_t)^\lambdaλ=0\lambda= 0時,損失函數就是標準交叉熵 。

損失函數意義

focal loss 稱爲焦點損失函數,通過改進標準的二元交叉熵損失函數來控制對正負樣本的訓練,爲了解決在one-stage目標檢測中正負樣本嚴重不均衡的一種策略。該損失函數的設計思想類似於boosting,降低容易分類的樣本對損失函數的影響,注重較難分類的樣本的訓練。

在常規的交叉熵函數的基礎上,添加一個係數項,其影響從下圖曲線來看可知:

  • 當樣本的預測分數較高(ptp_t較大,指的是模型判斷正確的概率較大)時,其計算所得的loss將變小,這一部分樣本視爲分類較好的數據,我們降低其在總體損失值中的比重;
  • 較難訓練的樣本則計算得到更大的loss值,模型將着重針對這些樣本進行訓練和梯度更新。

    進一步探討,當我們考慮類別的比重不相同時,我們可以給各個類別添加一個權重常數α\alpha,比如想使正樣本初始權重爲0.8,負樣本就爲0.2,那麼可以令α=0.8\alpha = 0.8,然後該權重常數乘以對應類別的交叉熵計算中得以生效。這樣就能夠平衡正負樣本的重要性。但是要解決簡單分類和困難分類樣本的問題則需要依賴 λ, λ越大,損失值計算結果越小,這能夠實現對容易樣本降低權重的平滑調節。對於物體檢測,實驗發現 λ=2時最優。
    個人認爲該損失函數的設計思想可以應用於其他同樣有樣本不均衡特點的分類任務。

Pytorch 實現

實現思想很簡單,就是先利用input和target計算出因子項(1pt)λ(1 - p_t)^\lambda,然後乘以標準交叉熵即可。

import torch
import torch.nn as nn
import torch.nn.functional as F 
import config as cfg 
from torch.autograd import Variable

class FocalLoss(nn.Module):
    def __init__(self,):
        super(FocalLoss, self).__init__()
        self.device = torch.device("cuda:" + str(cfg.DEVICE_ID) if torch.cuda.is_available() else "cpu")

    def forward(self, inputs, targets,gamma=2, focal_loss_alpha=0.8):        
    	# 計算正負樣本權重
        alpha_factor = torch.ones(targets.shape) * focal_loss_alpha
        alpha_factor = torch.where(torch.eq(targets, 1), alpha_factor, 1. - alpha_factor)
        # 計算因子項
        focal_weight = torch.where(torch.eq(targets, 1), 1. - inputs, inputs)
        # 得到最終的權重
        focal_weight = alpha_factor * torch.pow(focal_weight, focal_loss_gamma)
        targets = targets.type(torch.FloatTensor) 
        # 計算標準交叉熵
        bce = -(targets * torch.log(inputs) + (1. - targets) * torch.log(1. - inputs))
        # focal loss 
        cls_loss = focal_weight * bce
        return cls_loss.sum()

如果你要在GPU上跑,那麼你可以嘗試以下代碼。

import torch
import torch.nn as nn
import torch.nn.functional as F 
import config as cfg 
from torch.autograd import Variable

class FocalLoss(nn.Module):
    def __init__(self,):
        super(FocalLoss, self).__init__()
        self.device = torch.device("cuda:" + str(cfg.DEVICE_ID) if torch.cuda.is_available() else "cpu")

    def forward(self, inputs, targets):        
        gpu_targets = targets.cuda()
        alpha_factor = torch.ones(gpu_targets.shape).cuda() * cfg.focal_loss_alpha
        alpha_factor = torch.where(torch.eq(gpu_targets, 1), alpha_factor, 1. - alpha_factor)
        focal_weight = torch.where(torch.eq(gpu_targets, 1), 1. - inputs, inputs)
        focal_weight = alpha_factor * torch.pow(focal_weight, cfg.focal_loss_gamma)
        targets = targets.type(torch.FloatTensor)
        inputs = inputs.cuda()
        targets = targets.cuda()
        bce = F.binary_cross_entropy(inputs, targets)
        focal_weight = focal_weight.cuda()
        cls_loss = focal_weight * bce
        return cls_loss.sum()
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章