argmax函數軸取值

在深度學習中,argmax函數很常見,標籤採用ont-hot,輸出層用softmax激活可以加快網絡訓練速度和提升準確率,而在取值時,用argmax函數取值,涉及軸問題,數據在哪個軸,axis就填哪個軸,通常數據都在最後一個軸,下面舉例說明。
一維情形:
代碼:

import tensorflow as tf
import numpy as np

np_arry = (np.random.normal(size=[10]))

tf_arry=tf.Variable(np_arry)

init=tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    print(sess.run(tf.argmax(tf_arry,axis=0)))

結果:6

二維:

import tensorflow as tf
import numpy as np

np_arry = (np.random.normal(size=[3,10]))

tf_arry=tf.Variable(np_arry)

init=tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    print(sess.run(tf.argmax(tf_arry,axis=1)))

結果:[4 6 4]
注意,二維的時候,3是批次,10是數據,我們如果使axis=0,程序不會報錯,但結果是錯的,結果的意義是批次爲10,數據長度爲3
三維:

import tensorflow as tf
import numpy as np

np_arry = (np.random.normal(size=[3,4,10]))

tf_arry=tf.Variable(np_arry)

init=tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    print(sess.run(tf.argmax(tf_arry,axis=2)))

結果:
[[0 2 8 5]
[4 9 1 0]
[7 2 6 4]]
三維一般在RNN中比較常見

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章