Pytorch和Tensorflow中的交叉熵损失函数

原文地址

Pytorch系列目录

  • 导入支持

    import tensorflow as tf
    import torch
    import numpy as np
    
  • 生成测试数据

    onehot_labels = [[0,0,1,0,0],
                      [0,0,0,1,0],
                      [0,1,0,0,0],
                      [1,0,0,0,0]]
    labels = np.argmax(onehot_labels, axis=1)
    # [2 3 1 0]
    logits = [[-1.1258, -1.1524, -0.2506, -0.4339,  0.5988],
              [-1.5551, -0.3414,  1.8530,  0.4681, -0.1577],
              [ 1.4437,  0.2660,  1.3894,  1.5863,  0.9463],
              [-0.8437,  0.9318,  1.2590,  2.0050,  0.0537]]
    

    labels相当于真实的分类数据,其中onehot_labels是对类别号的标记方式进行的onehot处理;logits是网络生成的预测数据

  • 在TensorFlow中

    # 转成tf可以处理的张量
    tflabels = tf.constant(labels)
    tflabels_oh = tf.constant(onehot_labels)
    tflogits = tf.constant(logits)
    
    tfloss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tflabels, logits=tflogits)
    # 可以直接传入类别标号
    tfloss_oh = tf.nn.softmax_cross_entropy_with_logits(labels=tflabels_oh, logits=tflogits)
    # 需要传入onehot变化后的类别数据
    
    with tf.Session() as sess:
        lossvalue = sess.run(tfloss)
        loss_oh_value = sess.run(tfloss_oh)
        print('tfloss\t\t', lossvalue)
        print('tfloss_oh\t', loss_oh_value)
    

    其中用到了tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tflabels, logits=tflogits)和tf.nn.softmax_cross_entropy_with_logits(labels=tflabels_oh, logits=tflogits)两个方法,两者的主要区别是前者传入的labels可以是直接的数字类别标记,而后者的传入onehot化之后的labels

  • 在Pytorch中

    # 转成pytorch可以识别的张量
    ptlabels = torch.tensor(labels).int()
    ptlogits = torch.tensor(logits)
    
    ptloss = torch.nn.CrossEntropyLoss(reduce=False)(ptlogits, ptlabels.long())
    ptloss2 = torch.nn.NLLLoss(reduce=False)(torch.nn.LogSoftmax(dim=-1)(ptlogits), ptlabels.long())
    print('ptloss:\t\t', ptloss)
    print('ptloss2:\t', ptloss2)
    

    其中用到了两种方式实现,两者没有差别,后者是前者的内部实现方式;其中reduce的作用是是否对结果进一步处理,不过不设定,会默认输出当前结果的平均值(就只有一个值了)

  • 结果分析

    print结果如下

    tfloss		 [1.6081128 1.8093656 2.5681138 3.5499053]
    tfloss_oh	 [1.6081128 1.8093656 2.5681138 3.5499053]
    ptloss:		 tensor([1.6081, 1.8094, 2.5681, 3.5499])
    ptloss2:	 tensor([1.6081, 1.8094, 2.5681, 3.5499])
    

    可以看到,结果都是一样的

  • 说明

    TensorFlow版本1.14

  • 参考文献

    【TensorFlow】关于tf.nn.sparse_softmax_cross_entropy_with_logits()

    pytorch笔记:03)softmax和log_softmax,以及CrossEntropyLoss

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章