多分類問題的soft cross entropy 損失函數

在做多分類問題的時候,分類結果的損失函數經常使用交叉熵損失函數,對於預測結果,先經過softmax,然後經過log,然後再根據One hot向量只取得其中的一個值作爲損失函數的吸收值,比如,logsoftmax後的值爲[-0.6, -0.12 , -0.33, -0.334, -0.783],假設one hot label 爲[ 0,0,0,0,1 ],則損失函數的值爲 Loss = 0.783,,也就是說,只有一個值納入了計算,我就在想,可不可以將所有的值都納入計算呢,如果這樣的話,就得將label轉爲soft label ,爲 [0.2, 0.2 , 0.2, 0.2, 1],將0 Label 的地方設置爲 1/(label.shape[0]),再進行計算損失,則這個對應的損失函數的pytorch實現如下所示:

class softcrossentropy(nn.Module):
    def __init__(self):
        super( softcrossentropy, self ).__init__()
        self.cel = nn.CrossEntropyLoss()
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, labels):
        loss_tra_cel = self.cel( inputs, labels )

        ls = (-1) * self.logsoftmax(inputs)
        ls_sum = torch.sum(ls, dim=1) / inputs.shape[1]

        loss_soft_cel = torch.sum( ls_sum ) / inputs.shape[0]

        loss = loss_tra_cel + loss_soft_cel
        return loss

     分別利用上面的softccrossentroy和正常的crossentropy作爲分類損失函數損失函數,對market1501進行行人重試別的訓練,記錄每個mini batch的train loss 、train accuracy、val loss 、val accuracy,在訓練過程中將這些數據都保存起來,後期進行比對分析。訓練的數據在這了:鏈接: https://pan.baidu.com/s/18jYFFd9LaPrbd72OdHjdpw  密碼: hg9r

利用訓練保留的數據進行後期分析比對,文件目錄如下所示:

mat.py的代碼如下:

import scipy.io
import matplotlib.pyplot as plt
# train_acc
# train_loss
# val_loss
# val_acc

soft_res = scipy.io.loadmat('soft_mini_batch_data.mat')
soft_train_acc = soft_res['train_acc']

tra_res = scipy.io.loadmat('tra_mini_batch_data.mat')
tra_train_acc = tra_res['train_acc']

soft_acc_list = soft_train_acc[0].tolist()
train_acc_list = tra_train_acc[0].tolist()

data_length = int(len( soft_acc_list )/10)
x_label = [ i for i in range(data_length) ]

fig, ax = plt.subplots()
ax.plot( x_label, soft_acc_list[0:data_length],label = 'soft' )
ax.plot( x_label, train_acc_list[0:data_length], label = 'tra' )
ax.set_xlabel('mini batch num')
ax.set_ylabel('accuracy')
ax.set_title('god bless')
ax.legend()
plt.show()

      這裏比較了在不同的criterion下使用同樣的方法,對同樣的數據進行訓練的過程中的幾個數值的變化情況。這個代碼只比較了每個Mini batch的準確率情況。對比圖如下所示:

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章