4 損失函數-庖丁解牛之pytorch

基類定義

pytorch損失類也是模塊的派生,損失類的基類是_Loss,定義如下

class _Loss(Module):
    def __init__(self, size_average=None, reduce=None, reduction='elementwise_mean'):
        super(_Loss, self).__init__()
        if size_average is not None or reduce is not None:
            self.reduction = _Reduction.legacy_get_string(size_average, reduce)
        else:
            self.reduction = reduction

看這個類,有兩點我們知道:

  • 損失類是模塊
  • 不改變forward函數,但是具備執行功能
    還有其他模塊的性質

子類介紹

從_Loss派生的類有

名稱 說明 公式
_WeightedLoss 這個類只是申請了一個權重空間,功能和_Loss一樣
L1Loss X、Y可以是任意形狀的輸入,X與Y的 shape相同
PoissonNLLLoss 適合多目標分類
KLDivLoss 適用於連續分佈的距離計算
MSELoss 均方差
BCEWithLogitsLoss 多目標不需要經過sigmoid
HingeEmbeddingLoss Y中的元素只能爲1或-1 適用於學習非線性embedding、半監督學習。用於計算兩個輸入是否相似
MultiLabelMarginLoss 適用於多目標分類
SmoothL1Loss
SoftMarginLoss
CosineEmbeddingLoss
MarginRankingLoss
TripletMarginLoss

從_WeightedLoss繼續派生的函數有

名稱 說明
NLLLoss
BCELoss
CrossEntropyLoss
MultiLabelSoftMarginLoss
MultiMarginLoss
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章