圖像分割中涉及的損失函數(主要用來處理樣本不平衡)

前言

圖像分割中的loss函數繼承了深度學習模型中一般損失函數的所有特點,但是又有其自身的特點,即需要處理類別不平衡的問題,在進行圖像分割中,往往圖像中成爲背景的像素值佔絕大多數,而前景目標的像素只有很少一部分。
注:以下鏈接詳細介紹了深度學習模型中的一般損失函數。

參考鏈接:https://blog.csdn.net/weixin_38410551/article/details/104973011

如圖所示:
車道線實例圖

車道線分割圖
注:車道線只是佔很少一部分,大部分爲背景。

圖像分割處理的時候,經常要遇到這樣的樣本不均衡問題

解決辦法

主流的解決辦法,都是通過減少樣本中樣本數較多的類別損失函數權重,增加樣本中樣本數較少的類別損失函數的權重。這樣,預測樣本較少的類別,損失函數下降的更快,而預測樣本較多的類別,損失函數下降得慢。
注:防止過擬合,也採取了類似的辦法,通過增加權重參數的二次項,在優化損失函數的過程中,來降低權重的大小,從而達到防止過擬合的目的。(過擬合:某些權重參數過大

損失函數

1. log loss損失函數

參考鏈接:https://blog.csdn.net/weixin_38410551/article/details/104973011

2. WBE loss

思想:樣本數目較多的類別,要減小權重,樣本數目較少的類別,要增加權重。
公式:WCE=1Nn=1Nwrnlog(pn)+(1rn)log(1pn)WCE=−\frac{1}{N}\sum_{n = 1}^{N}wr_{n}​log(p_{n})+(1−r_{n}​)log(1−p_{n}​)
w=Nnpnnpnw = -\frac{N - \sum_{n}p_{n}}{\sum_{n}p_{n}}
其中,ww爲權重。
缺點:需要人爲去調整權重。

3. Focal loss

應用場景

應用於目標檢測的二分類問題。

思想

與WB loss思想相同,但是這裏的參數,可以自動化調節。

公式

1Ni=1N(αyi(1pi)γlog(pi)+(1α)(1yi)piγlog(1pi))−\frac{1}{N}\sum_{i = 1}^{N}(\alpha y_{i}(1 - p_{i})^{\gamma}​log(p_{i})+(1−\alpha)(1-y_{i})p_{i}^{\gamma}log(1−p_{i}​))
其基本思想就是,對於類別極度不均衡的情況下,網絡如果在log loss下會傾向於只預測負樣本,並且負樣本的預測概率pipi p_ipi​也會非常的高,回傳的梯度也很大。但是如果添加(1−pi)γ(1−pi)γ (1-p_i)^{\gamma}(1−pi​)γ則會使預測概率大的樣本得到的loss變小,而預測概率小的樣本,loss變得大,從而加強對正樣本的關注度。
可以改善目標不均衡的現象,對此情況比 binary_crossentropy 要好很多。

Dice loss

Dice 係數

2ABAB\frac{2A\bigcap B}{A\bigcup B}
其中,A爲樣本標籤值,B爲預測值,是一種集合相似度度量函數,通常用來衡量兩個集合的相似度。(取值範圍[0,1])。
分子乘以2, 是因爲分母存在重複計算A和B的重合元素。

Dice 差異函數

很簡單,1減去Dice係數,就是差異函數。Dice係數和Dice差異函數是對同一問題的兩種表述方式。12XYXY1 - \frac{2X\bigcap Y}{X\bigcup Y}

注:dice loss 比較適用於樣本極度不均的情況,一般的情況下,使用 dice loss 會對反向傳播造成不利的影響,容易使訓練變得不穩定.
示例代碼:

class BinaryDiceLoss(nn.Module):
    """Dice loss of binary class
    Args:
        smooth: A float number to smooth loss, and avoid NaN error, default: 1
        p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2
        predict: A tensor of shape [N, *]
        target: A tensor of shape same with predict
        reduction: Reduction method to apply, return mean over batch if 'mean',
            return sum if 'sum', return a tensor of shape [N,] if 'none'
    Returns:
        Loss tensor according to arg reduction
    Raise:
        Exception if unexpected reduction
    """
    def __init__(self, smooth=1, p=2, reduction='mean'):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth #避免分子爲0
        self.p = p #平方值
        self.reduction = reduction

    def forward(self, predict, target):
        assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
        predict = predict.contiguous().view(predict.shape[0], -1)
        target = target.contiguous().view(target.shape[0], -1)
        #num:分子
        num = 2*torch.sum(torch.mul(predict, target), dim=1) + self.smooth
        #den 分母
        den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.smooth
        #dice loss
        loss = 1 - num / den

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        elif self.reduction == 'none':
            return loss
        else:
            raise Exception('Unexpected reduction {}'.format(self.reduction))


class DiceLoss(nn.Module):
    """Dice loss, need one hot encode input
    Args:
        weight: An array of shape [num_classes,]
        ignore_index: class index to ignore
        predict: A tensor of shape [N, C, *]
        target: A tensor of same shape with predict
        other args pass to BinaryDiceLoss
    Return:
        same as BinaryDiceLoss
    """
    def __init__(self, weight=None, ignore_index=None, **kwargs):
        super(DiceLoss, self).__init__()
        self.kwargs = kwargs
        self.weight = weight
        self.ignore_index = ignore_index

    def forward(self, predict, target):
        #predict和target兩者形狀相同,否則報錯
        assert predict.shape == target.shape, 'predict & target shape do not match'
        #對BinaryDiceLoss進行實例化
        dice = BinaryDiceLoss(**self.kwargs)
        total_loss = 0
        predict = F.softmax(predict, dim=1)

        for i in range(target.shape[1]):
            if i != self.ignore_index:
                dice_loss = dice(predict[:, i], target[:, i])
                if self.weight is not None:
                    assert self.weight.shape[0] == target.shape[1], \
                        'Expect weight shape [{}], get[{}]'.format(target.shape[1], self.weight.shape[0])
                    dice_loss *= self.weights[i]
                total_loss += dice_loss

        return total_loss/target.shape[1]   #求取了一個均值

參考鏈接:https://blog.csdn.net/m0_37477175/article/details/83004746
https://blog.csdn.net/JMU_Ma/article/details/97533768?depth_1-utm_source=distribute.pc_relevant.none-task&utm_source=distribute.pc_relevant.none-task
https://blog.csdn.net/m0_37477175/article/details/83004746#Dice_loss_70

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章