Focal_loss
作者是一名深度學習工程師,主要研究計算機視覺與三維點雲處理,作者Github:Github.歡迎多多交流.
同時作者也是pytorch中文網的維護者之一,網站會不定時更新深度學習相關文章,歡迎大家多多支持.
retinanet是ICCV2017的Best Student Paper Award(最佳學生論文),何凱明是其作者之一.文章中最爲精華的部分就是損失函數 Focal loss的提出.
論文中提出類別失衡是造成two-stage與one-stage模型精確度差異的原因.並提出了Focal loss損失函數,通過調整類間平衡因子與難易度平衡因子.最終使one-stage模型達到了two-stage的精確度.
本項目基於pytorch實現focal loss,力圖給你原生pytorch損失函數的使用體驗.
一. 項目簡介
實現過程簡易明瞭,全中文備註.
阿爾法α 參數用於調整類別權重
伽馬γ 參數用於調整不同檢測難易樣本的權重,讓模型快速關注於困難樣本
完整項目地址:Github,歡迎star, fork. github還有其他視覺相關項目
github連接較慢的,可以去Gitee(國內的代碼託管網站),也有完整項目.
項目配有 Jupyter-Notebook 作爲focal loss使用例子.
二. 損失函數公式
focal loss 損失函數基於交叉熵損失函數,在交叉熵的基礎上,引入了α與γ兩個不同的調整因子.
2.1 交叉熵損失
2.2 帶平衡因子的交叉熵
2.3 Focal loss損失
加入 (1-pt)γ 平衡難易樣本的權重,通過γ縮放因子調整,retainnet默認γ=2
2.4 帶平衡因子的Focal損失
論文中最終爲帶平衡因子的focal loss, 本項目實現的也是這個版本
三. Focal loss實現
# -*- coding: utf-8 -*-
# @Author : LG
from torch import nn
import torch
from torch.nn import functional as F
class focal_loss(nn.Module):
def __init__(self, alpha=0.25, gamma=2, num_classes = 3, size_average=True):
"""
focal_loss損失函數, -α(1-yi)**γ *ce_loss(xi,yi)
步驟詳細的實現了 focal_loss損失函數.
:param alpha: 阿爾法α,類別權重. 當α是列表時,爲各類別權重,當α爲常數時,類別權重爲[α, 1-α, 1-α, ....],常用於 目標檢測算法中抑制背景類 , retainnet中設置爲0.25
:param gamma: 伽馬γ,難易樣本調節參數. retainnet中設置爲2
:param num_classes: 類別數量
:param size_average: 損失計算方式,默認取均值
"""
super(focal_loss,self).__init__()
self.size_average = size_average
if isinstance(alpha,list):
assert len(alpha)==num_classes # α可以以list方式輸入,size:[num_classes] 用於對不同類別精細地賦予權重
print("Focal_loss alpha = {}, 將對每一類權重進行精細化賦值".format(alpha))
self.alpha = torch.Tensor(alpha)
else:
assert alpha<1 #如果α爲一個常數,則降低第一類的影響,在目標檢測中爲第一類
print(" --- Focal_loss alpha = {} ,將對背景類進行衰減,請在目標檢測任務中使用 --- ".format(alpha))
self.alpha = torch.zeros(num_classes)
self.alpha[0] += alpha
self.alpha[1:] += (1-alpha) # α 最終爲 [ α, 1-α, 1-α, 1-α, 1-α, ...] size:[num_classes]
self.gamma = gamma
def forward(self, preds, labels):
"""
focal_loss損失計算
:param preds: 預測類別. size:[B,N,C] or [B,C] 分別對應與檢測與分類任務, B 批次, N檢測框數, C類別數
:param labels: 實際類別. size:[B,N] or [B]
:return:
"""
# assert preds.dim()==2 and labels.dim()==1
preds = preds.view(-1,preds.size(-1))
self.alpha = self.alpha.to(preds.device)
preds_softmax = F.softmax(preds, dim=1) # 這裏並沒有直接使用log_softmax, 因爲後面會用到softmax的結果(當然你也可以使用log_softmax,然後進行exp操作)
preds_logsoft = torch.log(preds_softmax)
preds_softmax = preds_softmax.gather(1,labels.view(-1,1)) # 這部分實現nll_loss ( crossempty = log_softmax + nll )
preds_logsoft = preds_logsoft.gather(1,labels.view(-1,1))
self.alpha = self.alpha.gather(0,labels.view(-1))
loss = -torch.mul(torch.pow((1-preds_softmax), self.gamma), preds_logsoft) # torch.pow((1-preds_softmax), self.gamma) 爲focal loss中 (1-pt)**γ
loss = torch.mul(self.alpha, loss.t())
if self.size_average:
loss = loss.mean()
else:
loss = loss.sum()
return loss
詳細的使用例子請到Github查看jupyter-notebook.
說明
完整項目地址:Github,歡迎star, fork.
僅限用於交流學習,如需引用,請聯繫作者.