從caffe中我們看到softmax有下面這些參數
// Message that stores parameters used by SoftmaxLayer, SoftmaxWithLossLayer
message SoftmaxParameter {
enum Engine {
DEFAULT = 0;
CAFFE = 1;
CUDNN = 2;
}
optional Engine engine = 1 [default = DEFAULT];
// The axis along which to perform the softmax -- may be negative to index
// from the end (e.g., -1 for the last axis).
// Any other axes will be evaluated as independent softmaxes.
optional int32 axis = 2 [default = 1];
}
一般來說axis也不需要修改,默認設爲1,即在c上做計算。
那麼設置不同的axis,結果有什麼不同?我們舉個例子一目瞭然;
import tensorflow as tf
import numpy as np
a = np.array([[1, 2, 3], [1, 2, 3]])
a = tf.cast(a, tf.float32)
#>>> a
#tf.Tensor: shape=(2, 3), dtype=float32, numpy=
#array([[1., 2., 3.],
# [1., 2., 3.]], dtype=float32)>
#
s1 = tf.nn.softmax(a,axis=0)
print(s1)
#tf.Tensor(
#[[0.5 0.5 0.5]
#[0.5 0.5 0.5]], shape=(2, 3), dtype=float32)
s2 = tf.nn.softmax(a,axis=1)
print(s2)
#tf.Tensor(
#[[0.09003057 0.24472848 0.66524094]
#[0.09003057 0.24472848 0.66524094]], shape=(2, 3), dtype=float32)
我們來看看計算過程:
axis = 0時(表示縱軸,方向從上到下)
axis = 1時(表示橫軸,方向從左到右)
再舉一個三維數組深入理解一下
import tensorflow as tf
import numpy as np
a = np.array([[[1, 2, 3], [1, 2, 3]],[[4, 5, 6], [4, 5, 6]]])
a = tf.cast(a, tf.float32)
#>>> a
#<tf.Tensor: shape=(2, 2, 3), dtype=float32, numpy=
#array([[[1., 2., 3.],
# [1., 2., 3.]],
#
# [[4., 5., 6.],
# [4., 5., 6.]]], dtype=float32)>
#
s1 = tf.nn.softmax(a,axis=0)
print(s1)
#tf.Tensor(
#[[[0.04742587 0.04742587 0.04742587]
# [0.04742587 0.04742587 0.04742587]]
#
# [[0.95257413 0.95257413 0.95257413]
# [0.95257413 0.95257413 0.95257413]]], shape=(2, 2, 3), dtype=float32)
s2 = tf.nn.softmax(a,axis=1)
print(s2)
#tf.Tensor(
#[[[0.5 0.5 0.5]
# [0.5 0.5 0.5]]
#
# [[0.5 0.5 0.5]
# [0.5 0.5 0.5]]], shape=(2, 2, 3), dtype=float32)
s3 = tf.nn.softmax(a,axis=2)
print(s3)
#tf.Tensor(
#[[[0.09003057 0.24472848 0.66524094]
# [0.09003057 0.24472848 0.66524094]]
#
# [[0.09003057 0.24472848 0.66524094]
# [0.09003057 0.24472848 0.66524094]]], shape=(2, 2, 3), dtype=float32)
計算過程如下:
axis=0時
。。。再來重複上面3個計算3次
axis=1時(1和2的計算和上面二維差不多)
。。。
axis=2時
…