softmax中axis參數

從caffe中我們看到softmax有下面這些參數

// Message that stores parameters used by SoftmaxLayer, SoftmaxWithLossLayer
message SoftmaxParameter {
  enum Engine {
    DEFAULT = 0;
    CAFFE = 1;
    CUDNN = 2;
  }
  optional Engine engine = 1 [default = DEFAULT];

  // The axis along which to perform the softmax -- may be negative to index
  // from the end (e.g., -1 for the last axis).
  // Any other axes will be evaluated as independent softmaxes.
  optional int32 axis = 2 [default = 1];
}

一般來說axis也不需要修改,默認設爲1,即在c上做計算。
那麼設置不同的axis,結果有什麼不同?我們舉個例子一目瞭然;

import tensorflow as tf
import numpy as np

a = np.array([[1, 2, 3], [1, 2, 3]])
a = tf.cast(a, tf.float32)
#>>> a
#tf.Tensor: shape=(2, 3), dtype=float32, numpy=
#array([[1., 2., 3.],
#       [1., 2., 3.]], dtype=float32)>
# 

s1 = tf.nn.softmax(a,axis=0)
print(s1)
#tf.Tensor(
#[[0.5 0.5 0.5]
#[0.5 0.5 0.5]], shape=(2, 3), dtype=float32)

s2 = tf.nn.softmax(a,axis=1)
print(s2)
#tf.Tensor(
#[[0.09003057 0.24472848 0.66524094]
#[0.09003057 0.24472848 0.66524094]], shape=(2, 3), dtype=float32)


我們來看看計算過程:
axis = 0時(表示縱軸,方向從上到下)
在這裏插入圖片描述
在這裏插入圖片描述
在這裏插入圖片描述

axis = 1時(表示橫軸,方向從左到右)
在這裏插入圖片描述
在這裏插入圖片描述
在這裏插入圖片描述
再舉一個三維數組深入理解一下

import tensorflow as tf
import numpy as np

a = np.array([[[1, 2, 3], [1, 2, 3]],[[4, 5, 6], [4, 5, 6]]])
a = tf.cast(a, tf.float32)
#>>> a
#<tf.Tensor: shape=(2, 2, 3), dtype=float32, numpy=
#array([[[1., 2., 3.],
#        [1., 2., 3.]],
#
#       [[4., 5., 6.],
#        [4., 5., 6.]]], dtype=float32)>
#
s1 = tf.nn.softmax(a,axis=0)
print(s1)
#tf.Tensor(
#[[[0.04742587 0.04742587 0.04742587]
#  [0.04742587 0.04742587 0.04742587]]
#
# [[0.95257413 0.95257413 0.95257413]
#  [0.95257413 0.95257413 0.95257413]]], shape=(2, 2, 3), dtype=float32)

s2 = tf.nn.softmax(a,axis=1)
print(s2)
#tf.Tensor(
#[[[0.5 0.5 0.5]
#  [0.5 0.5 0.5]]
#
# [[0.5 0.5 0.5]
#  [0.5 0.5 0.5]]], shape=(2, 2, 3), dtype=float32)

s3 = tf.nn.softmax(a,axis=2)
print(s3)
#tf.Tensor(
#[[[0.09003057 0.24472848 0.66524094]
#  [0.09003057 0.24472848 0.66524094]]
#
# [[0.09003057 0.24472848 0.66524094]
#  [0.09003057 0.24472848 0.66524094]]], shape=(2, 2, 3), dtype=float32)

計算過程如下:
axis=0時
在這裏插入圖片描述
在這裏插入圖片描述
在這裏插入圖片描述
。。。再來重複上面3個計算3次

axis=1時(1和2的計算和上面二維差不多)
在這裏插入圖片描述
在這裏插入圖片描述
在這裏插入圖片描述
。。。

在這裏插入圖片描述
axis=2時
在這裏插入圖片描述
在這裏插入圖片描述
在這裏插入圖片描述
在這裏插入圖片描述

在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章