在深度學習中,argmax函數很常見,標籤採用ont-hot,輸出層用softmax激活可以加快網絡訓練速度和提升準確率,而在取值時,用argmax函數取值,涉及軸問題,數據在哪個軸,axis就填哪個軸,通常數據都在最後一個軸,下面舉例說明。
一維情形:
代碼:
import tensorflow as tf
import numpy as np
np_arry = (np.random.normal(size=[10]))
tf_arry=tf.Variable(np_arry)
init=tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
print(sess.run(tf.argmax(tf_arry,axis=0)))
結果:6
二維:
import tensorflow as tf
import numpy as np
np_arry = (np.random.normal(size=[3,10]))
tf_arry=tf.Variable(np_arry)
init=tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
print(sess.run(tf.argmax(tf_arry,axis=1)))
結果:[4 6 4]
注意,二維的時候,3是批次,10是數據,我們如果使axis=0,程序不會報錯,但結果是錯的,結果的意義是批次爲10,數據長度爲3
三維:
import tensorflow as tf
import numpy as np
np_arry = (np.random.normal(size=[3,4,10]))
tf_arry=tf.Variable(np_arry)
init=tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
print(sess.run(tf.argmax(tf_arry,axis=2)))
結果:
[[0 2 8 5]
[4 9 1 0]
[7 2 6 4]]
三維一般在RNN中比較常見