pytorch中的StableBCELoss-----深度學習中的數值穩定性問題一例

 一、傳統的BCELoss(二元交叉熵 )

1 交叉熵

先來看下信息論中交叉熵的形式

 

 

交叉熵是用來描述兩個分佈的距離的,神經網絡訓練的目的就是使 g(x) 逼近 p(x)。

2 二元交叉熵(BCE)與多元交叉熵區別

交叉熵損失函數兩種形式:

這兩個都是交叉熵損失函數,但是看起來長的卻有天壤之別。爲什麼同是交叉熵損失函數,長的卻不一樣呢?

因爲這兩個交叉熵損失函數對應不同的最後一層的輸出:
第一個對應的最後一層是softmax,第二個對應的最後一層是sigmoid。

(1)二元交叉熵

BCE即binary cross-entropy,二元交叉熵。

當你執行二元分類任務時,可以選擇該損失函數。如果你使用BCE(二元交叉熵)損失函數,則只需一個輸出節點即可將數據分爲兩類。輸出值應通過sigmoid激活函數,以便輸出在(0-1)範圍內,其公式爲:

例如,你有一個神經網絡,該網絡獲取與大氣有關的數據並預測是否會下雨。如果輸出大於0.5,則網絡將其分類爲會下雨;如果輸出小於0.5,則網絡將其分類爲不會下雨。即概率得分值越大,下雨的機會越大。

 

訓練網絡時,如果標籤是下雨,則輸入網絡的目標值應爲1,否則爲0。

重要的一點是,如果你使用BCE損失函數,則節點的輸出應介於(0-1)之間。這意味着你必須在最終輸出中使用sigmoid激活函數。因爲sigmoid函數可以把任何實數值轉換(0–1)的範圍。(也就是輸出概率值)

(2)多分類交叉熵

當你執行多類分類任務時,可以選擇該損失函數。如果使用CCE(多分類交叉熵)損失函數,則輸出節點的數量必須與這些類相同。最後一層的輸出應該通過softmax激活函數,以便每個節點輸出介於(0-1)之間的概率值。其公式爲:

例如,你有一個神經網絡,它讀取圖像並將其分類爲貓或狗。如果貓節點具有高概率得分,則將圖像分類爲貓,否則分類爲狗。基本上,如果某個類別節點具有最高的概率得分,圖像都將被分類爲該類別。

 

爲了在訓練時提供目標值,你必須對它們進行一次one-hot編碼。如果圖像是貓,則目標向量將爲(1,0),如果圖像是狗,則目標向量將爲(0,1)。基本上,目標向量的大小將與類的數目相同,並且對應於實際類的索引位置將爲1,所有其他的位置都爲零。

 

二、pytorch中BCE代碼

 

class BCELoss(_WeightedLoss):
    __constants__ = ['reduction', 'weight']

    def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'):
        super(BCELoss, self).__init__(weight, size_average, reduce, reduction)

    def forward(self, input, target):
        return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)

 

三、數值穩定的StableBCELoss代碼

參考《深度學習中的數值穩定性問題一例

(1)改進後的BCELoss代碼

class StableBCELoss(nn.modules.Module):
       def __init__(self):
             super(StableBCELoss, self).__init__()
       def forward(self, input, target):
             neg_abs = - input.abs()
             loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
             return loss.mean()

上面這段代碼中的 forward 函數用來計算所謂的 binary cross entropy,其實也就是邏輯迴歸中的損失函數:

其中 sigmoid 函數定義如下:

 

(2)改進原因-數值穩定性分析

 

既然只是爲了計算各交叉熵而已,爲什麼要搞這麼複雜的兩行呢:

neg_abs = - input.abs()
loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()

原因在於e^{x}x是較大正數的時候會出現溢出,也就是超出計算機能表達的範圍,從而造成計算誤差,而上面這段代碼就能在避免計算較大正數的自然指數的同時依然能正確的計算交叉熵。可以發現上面代碼的指數運算其操作數是非正數。

下面我們檢查下這段代碼的正確性,代碼中的 input 對應L定義中的 x,target 對應 L定義中的y

input.clamp(min=0) 的意思是 max(0,x),下面分四種情況討論這個方法的正確性:

(3)擴展 

在 softmax 的計算中也有指數運算,也存在這樣的數值穩定性問題: 

解決思路也是一樣的,就是避免較大正數的自然指數運算:

max表示所有x_{i}的最大數,這樣的一個變化相當於分子分母同除以一個正數,所以計算結果是不變的,另外每個 x_{i}-max顯然都小於0,於是就不存在大正數的自然指數運算這樣的問題了。

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章