Pytorch和Tensorflow中的交叉熵損失函數

原文地址

Pytorch系列目錄

  • 導入支持

    import tensorflow as tf
    import torch
    import numpy as np
    
  • 生成測試數據

    onehot_labels = [[0,0,1,0,0],
                      [0,0,0,1,0],
                      [0,1,0,0,0],
                      [1,0,0,0,0]]
    labels = np.argmax(onehot_labels, axis=1)
    # [2 3 1 0]
    logits = [[-1.1258, -1.1524, -0.2506, -0.4339,  0.5988],
              [-1.5551, -0.3414,  1.8530,  0.4681, -0.1577],
              [ 1.4437,  0.2660,  1.3894,  1.5863,  0.9463],
              [-0.8437,  0.9318,  1.2590,  2.0050,  0.0537]]
    

    labels相當於真實的分類數據,其中onehot_labels是對類別號的標記方式進行的onehot處理;logits是網絡生成的預測數據

  • 在TensorFlow中

    # 轉成tf可以處理的張量
    tflabels = tf.constant(labels)
    tflabels_oh = tf.constant(onehot_labels)
    tflogits = tf.constant(logits)
    
    tfloss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tflabels, logits=tflogits)
    # 可以直接傳入類別標號
    tfloss_oh = tf.nn.softmax_cross_entropy_with_logits(labels=tflabels_oh, logits=tflogits)
    # 需要傳入onehot變化後的類別數據
    
    with tf.Session() as sess:
        lossvalue = sess.run(tfloss)
        loss_oh_value = sess.run(tfloss_oh)
        print('tfloss\t\t', lossvalue)
        print('tfloss_oh\t', loss_oh_value)
    

    其中用到了tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tflabels, logits=tflogits)和tf.nn.softmax_cross_entropy_with_logits(labels=tflabels_oh, logits=tflogits)兩個方法,兩者的主要區別是前者傳入的labels可以是直接的數字類別標記,而後者的傳入onehot化之後的labels

  • 在Pytorch中

    # 轉成pytorch可以識別的張量
    ptlabels = torch.tensor(labels).int()
    ptlogits = torch.tensor(logits)
    
    ptloss = torch.nn.CrossEntropyLoss(reduce=False)(ptlogits, ptlabels.long())
    ptloss2 = torch.nn.NLLLoss(reduce=False)(torch.nn.LogSoftmax(dim=-1)(ptlogits), ptlabels.long())
    print('ptloss:\t\t', ptloss)
    print('ptloss2:\t', ptloss2)
    

    其中用到了兩種方式實現,兩者沒有差別,後者是前者的內部實現方式;其中reduce的作用是是否對結果進一步處理,不過不設定,會默認輸出當前結果的平均值(就只有一個值了)

  • 結果分析

    print結果如下

    tfloss		 [1.6081128 1.8093656 2.5681138 3.5499053]
    tfloss_oh	 [1.6081128 1.8093656 2.5681138 3.5499053]
    ptloss:		 tensor([1.6081, 1.8094, 2.5681, 3.5499])
    ptloss2:	 tensor([1.6081, 1.8094, 2.5681, 3.5499])
    

    可以看到,結果都是一樣的

  • 說明

    TensorFlow版本1.14

  • 參考文獻

    【TensorFlow】關於tf.nn.sparse_softmax_cross_entropy_with_logits()

    pytorch筆記:03)softmax和log_softmax,以及CrossEntropyLoss

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章