TensorFlow內置交叉熵損失函數

今天講解Tensorflow內置4中交叉熵損失函數:

  1. tf.nn.sigmoid_cross_entropy_with_logits
  2. tf.nn.softmax_cross_entropy_with_logits_v2
  3. tf.nn.sparse_softmax_cross_entropy_with_logits
  4. tf.nn.weighted_cross_entropy_with_logits

1. tf.nn.sigmoid_cross_entropy_with_logits

寫在前面:這個損失函數要求 logits/labels 類型爲 float32或float64,因爲在使用時不要將 labels 定義成了 int 型!

tf.nn.sigmoid_cross_entropy_with_logits(
    _sentinel=None,
    labels=None,
    logits=None,
    name=None
)

看看這個損失函數應用於什麼場景?

Measures the probability error in discrete classification tasks in which each class is independent and not mutually exclusive. For instance, one could perform multilabel classification where a picture can contain both an elephant and a dog at the same time.

這個損失函數計算的是 概率誤差,各個類別相互獨立但不必相互排斥。(注:logits 表示未歸一化處理的概率, 即網絡輸出層的輸出結果,因爲損失函數自己會先用Sigmoid/Softmax進行歸一化,對此請參見這篇博客
For brevity, let x = logits, z = labels. The logistic loss is
zlog(sigmoid(x))+(1z)log(1sigmoid(x))=zlog(1/(1+exp(x)))+(1z)log(exp(x)/(1+exp(x)))=zlog(1+exp(x))+(1z)(log(exp(x))+log(1+exp(x)))=zlog(1+exp(x))+(1z)(x+log(1+exp(x))=(1z)x+log(1+exp(x))=xxz+log(1+exp(x)) \begin{aligned} &z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) \\ = &z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x))) \\ = &z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x))) \\ = &z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)) \\ = &(1 - z) * x + log(1 + exp(-x)) \\ = &x - x * z + log(1 + exp(-x)) \\ \end{aligned} For x < 0, to avoid overflow in exp(-x), we reformulate the above
xxz+log(1+exp(x))=log(exp(x))xz+log(1+exp(x))=xz+log(1+exp(x)) \begin{aligned} &x - x * z + log(1 + exp(-x)) \\ = &log(exp(x)) - x * z + log(1 + exp(-x)) \\ = &- x * z + log(1 + exp(x)) \end{aligned}


PS:那麼什麼是溢出呢?
定義:當變量的數據類型所提供的位數無法適應某個值時,就會發生溢出(上溢)或下溢。
不妨來看一個例子,假設在一個使用了 2 個字節內存的 short int 類型變量中存儲了以下值:
在這裏插入圖片描述
這是 32 767 的二進制表示,也是能存儲在該數據類型中的最大值。這裏先不講負數如何存儲的細節,只要知道 short int 數據類型既可以存儲正數也可以存儲負數就可以了。高階位(即最左側位)是 0 的數字被解釋爲正數,高階位爲 1 的數字則被解釋爲負數。
如果上面示例中存儲的數字加 1,則該變量將變成以下位模式:
在這裏插入圖片描述
但這不是 32 768。相反,它被解釋爲負數,所以這不是預期的結果。二進制 1 已經“流入”到高階位的位置,這就是所謂的溢出(上溢)。
同樣地,當一個整數變量保存的數值在其數據類型負值範圍的最遠端(即最小負值),那麼當它被減去 1 時,其高位中的 1 將變爲 0,結果數將被解釋爲正數。這是溢出的另一個例子。

除了溢出以外,浮點值還會遇到下溢的情況。當一個值太接近於零時,就可能會發生這種問題,過小的數字需要更多數位的精度來表示它,因而無法存儲在保存它的變量中。
簡而言之,溢出就是變量數據類型的位數存儲不了給定數據!

x<0x < 0 且非常小時,對於 exe^{-x} 值可能會非常大,造成溢出!


Hence, to ensure stability and avoid overflow, the implementation uses this equivalent formulation.
max(x,0)xz+log(1+exp(abs(x))) max(x, 0) - x * z + log(1 + exp(-abs(x)))
再來看一個運用這個損失函數的具體例子:

import numpy as np
import tensorflow as tf
 
def sigmoid(x):
    return 1.0/(1+np.exp(-x))
 
labels = np.array([[1.,0.,0.],[0.,1.,0.],[0.,0.,1.]])
logits = np.array([[11.,8.,7.],[10.,14.,3.],[1.,2.,4.]])

# 根據API內部源碼計算Loss
# 單目標
y_pred = sigmoid(logits)
prob_error1 = -labels * np.log(y_pred) - (1 - labels) * np.log(1 - y_pred)
print(".................................................................")
print("----------單目標loss: \n", prob_error1)

# 多目標:張圖片可以有多個類別標籤
labels1 = np.array([[0.,1.,0.],[1.,1.,0.],[0.,0.,1.]]) 
logits1 = np.array([[1.,8.,7.],[10.,14.,3.],[1.,2.,4.]])
y_pred1 = sigmoid(logits1)
prob_error2 = -labels1 * np.log(y_pred1) - (1 - labels1) * np.log(1-y_pred1)
print(".................................................................")
print("----------多目標loss: \n", prob_error2)


with tf.Session() as sess:
    # 直接調用API, logits 表示未歸一化處理的概率
    print("***********************************************************************")
    print("----------單目標loss: \n", sess.run(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,logits=logits)))
    print("***********************************************************************")
    print("----------多目標loss: \n", sess.run(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels1,logits=logits1)))

觀察結果你發現了什麼?並嘗試一下當 x<0x < 0 且非常小時造成溢出是什麼樣的?最後用上面的優化公式解決溢出問題。

2. tf.nn.softmax_cross_entropy_with_logits_v2

tf.nn.softmax_cross_entropy_with_logits_v2(
    labels,
    logits,
    axis=None,
    name=None,
    dim=None
)

看看這個損失函數應用於什麼場景?

Measures the probability error in discrete classification tasks in which the classes are mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is labeled with one and only one label: an image can be a dog or a truck, but not both.

NOTE: While the classes are mutually exclusive, their probabilities need not be. All that is required is that each row of labels is a valid probability distribution. If they are not, the computation of the gradient will be incorrect.

這個損失函數只適用於單目標的二分類或多分類問題,即一張圖片只能有一個類別標籤,而tf.nn.sigmoid_cross_entropy_with_logits一張圖片可以有多個類別標籤。另外,有效概率分佈是指所有的類別是互斥的,但它們對應的概率不須如此。
注意:

  1. tf.nn.sparse_softmax_cross_entropy_with_logits要求概率有且只有一個類別。
  2. 該 op 內部對 logits 有 softmax 處理,效率更高,因此其輸入需要未歸一化的 logits。 即不需使用 softmax 的輸出, 否則結果會不正確。
  3. tf.nn.softmax_cross_entropy_with_logits 反向傳播只會發生在 logits中;tf.nn.softmax_cross_entropy_with_logits_v2 反向傳播將發生在 logits 和 labels 中。 如果要禁止反向傳播到 labels 中,請先將 labels 張量傳遞一個tf.stop_gradient參數,然後再將其傳遞給此函數。

3. tf.nn.sparse_softmax_cross_entropy_with_logits

tf.nn.sparse_softmax_cross_entropy_with_logits(
    _sentinel=None,
    labels=None,
    logits=None,
    name=None
)

看看這個損失函數應用於什麼場景?

Measures the probability error in discrete classification tasks in which the classes are mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is labeled with one and only one label: an image can be a dog or a truck, but not both.

NOTE: For this operation, the probability of a given label is considered exclusive. That is, soft classes are not allowed, and the labels vector must provide a single specific index for the true class for each row of logits (each minibatch entry). For soft softmax classification with a probability distribution for each entry, see softmax_cross_entropy_with_logits_v2.

這個損失函數與 tf.nn.softmax_cross_entropy_with_logits_v2 基本一致,不同之處在於:tf.nn.sparse_softmax_cross_entropy_with_logits 給定 label 對應的概率也必須是互斥的,即 labels向量 只能在一個特定位置表示真實類別。

4. tf.nn.weighted_cross_entropy_with_logits

tf.nn.softmax_cross_entropy_with_logits_v2(
    labels,
    logits,
    axis=None,
    name=None,
    dim=None
)

看看這個損失函數應用於什麼場景?

This is like sigmoid_cross_entropy_with_logits() except that pos_weight, allows one to trade off recall and precision by up- or down-weighting the cost of a positive error relative to a negative error. The usual cross-entropy cost is defined as: labels * -log(sigmoid(logits)) + (1 - labels) * -log(1 - sigmoid(logits)) .
A value pos_weights > 1 decreases the false negative count, hence increasing the recall. Conversely setting pos_weights < 1 decreases the false positive count and increases the precision.


通常的交叉熵成本定義爲:
targetslog(sigmoid(logits))+(1targets)log(1sigmoid(logits)) targets * -log(sigmoid(logits)) + (1 - targets) * -log(1 - sigmoid(logits)) pos_weight是作爲損失表達式中的正目標項的乘法系數引入的:
targetslog(sigmoid(logits))pos_weight+(1targets)log(1sigmoid(logits)) targets * -log(sigmoid(logits)) * pos\_weight + (1 - targets) * -log(1 - sigmoid(logits))


其實這個損失函數類似 sigmoid_cross_entropy_with_logits(),因此也是用於解決二分類問題的。與 sigmoid_cross_entropy_with_logits() 的區別就在於這個損失函數添加了一個權重參數,用於調節正樣本損失的比例,顯示這是針對正負樣本不均衡時提出的方法。

For brevity, let x = logits, z = labels, q = pos_weight. The loss is:
qzlog(sigmoid(x))+(1z)log(1sigmoid(x))=qzlog(1/(1+exp(x)))+(1z)log(exp(x)/(1+exp(x)))=qzlog(1+exp(x))+(1z)(log(exp(x))+log(1+exp(x)))=qzlog(1+exp(x))+(1z)(x+log(1+exp(x))=(1z)x+(qz+1z)log(1+exp(x))=(1z)x+(1+(q1)z)log(1+exp(x)) \begin{aligned} &qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) \\ = &qz * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x))) \\ = &qz * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x))) \\ = &qz * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)) \\ = &(1 - z) * x + (qz + 1 - z) * log(1 + exp(-x)) \\ = &(1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x)) \\ \end{aligned} Setting l = (1 + (q - 1) * z), to ensure stability and avoid overflow, the implementation uses:
(1z)x+l(log(1+exp(abs(x)))+max(x,0)) (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))

參考:TF官網API介紹

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章