tensorflow筆記 cross entropy loss

交叉熵損失函數是模型中非常常見的一種損失函數,tensorflow中有一個計算交叉熵的函數:tf.nn.sigmoid_cross_entropy_with_logits,也可以調用keras中的函數: tf.keras.backend.binary_crossentropy,需要注意的是兩者的輸入有一些不同。


先來看看tf自帶的sigmoid_cross_entropy_with_logits:

tf.nn.sigmoid_cross_entropy_with_logits(
    _sentinel=None,
    labels=None,
    logits=None,
    name=None
)

sigmoid_cross_entropy_with_logits()需要兩個參數,神經網絡最後一層的輸出logits和真實值labels。內部會經過一次sigmoid再計算cross entropy loss,計算方式如下所示:

令x = logits, z = labels
Loss = - z * log(sigmoid(x)) - (1 - z) * log(1 - sigmoid(x))
= - z * log(1 / (1 + exp(-x))) - (1 - z) * log(exp(-x) / (1 + exp(-x)))
= z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
= z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
= (1 - z) * x + log(1 + exp(-x))
= x - x * z + log(1 + exp(-x))

即後面代碼中的prob_error2式


tf.keras.backend.binary_crossentropy與sigmoid_cross_entropy_with_logits輸入有一些不一樣,因爲keras是已經內部封裝好的函數,所以要求的輸入是神經網絡經過sigmoid後的輸出,binary_crossentropy在內部會先將輸入轉化爲logits,然後再調用tf.nn.sigmoid_cross_entropy_with_logits計算交叉熵。


下面做一個簡單的驗證,注意兩個函數輸入的不同

import numpy as np
import tensorflow as tf

def sigmoid(x):
    return 1.0/(1+np.exp(-x))

labels=np.array([[1.,0.,0.],[0.,1.,0.],[0.,0.,1.]])
logits=np.array([[1.,0.,0.],[0.,1.,0.],[0.,0.,1.]])
label = tf.convert_to_tensor(labels, np.float32)
logit = tf.convert_to_tensor(logits, np.float32)
y_pred=sigmoid(logits)
y_preds = tf.convert_to_tensor(y_pred, np.float32)
prob_error1=-labels*np.log(y_pred)-(1-labels)*np.log(1-y_pred)
prob_error2=-logits*labels+np.log(1+np.exp(logits))
print(prob_error1)
print(prob_error2)


print(".............")
with tf.Session() as sess:
    #print(sess.run(label))
    print(sess.run(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,logits=logits)))
    #print('='*20)
    #print(sess.run(label))
    print(sess.run(tf.keras.backend.binary_crossentropy(label,y_preds)))
    #print(sess.run(prob_error2))

'''
[[0.31326169 0.69314718 0.69314718]
 [0.69314718 0.31326169 0.69314718]
 [0.69314718 0.69314718 0.31326169]]
[[0.31326169 0.69314718 0.69314718]
 [0.69314718 0.31326169 0.69314718]
 [0.69314718 0.69314718 0.31326169]]
.............
[[0.31326169 0.69314718 0.69314718]
 [0.69314718 0.31326169 0.69314718]
 [0.69314718 0.69314718 0.31326169]]
[[0.3132617 0.6931472 0.6931472]
 [0.6931472 0.3132617 0.6931472]
 [0.6931472 0.6931472 0.3132617]]
'''
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章