argmax(input, axis=None, name=None, dimension=None, output_type=tf.int64)
Returns the index with the largest value across axes of a tensor. (deprecated arguments)
根據axis的值返回行或者列最大值的下標,axis取值[-2,2)
上代碼
#創建一個2*3的數組,使用隨機種子,保證數據不變
a = tf.Variable(tf.random_normal([2,3],seed = 12345))
#初始化變量
init = tf.global_variables_initializer()
#對每一列進行計算,返回最大值下標
b = tf.argmax(a,0)
#啓動會話層
with tf.Session() as sess:
sess.run(init)
print(sess.run(a))
print(sess.run(b))
輸出結果:
原始數據:
[[ 0.88424665 0.07843047 0.13639879]
[-0.6109575 1.8525681 -1.1506747 ]]
返回最大值下標索引
[0 1 0]
它是如何返回 [ 0 1 0]的呢
計算流程:
如圖,把原始數據使用兩根紅線豎着分割成三份(按列分割),從上到下進行對比(按行對比),即:0.088424665和-0.6109575進行對比,返回最大值的下標,0.07843047和1.8525681進行對比,返回最大值的下標,0.13639879和-1.1506747進行對比,返回最大值的下標。
0.088424665(下標0)大於 -0.6109575(下標1) ,返回0
0.07843047(下標0) 小於 1.8525681(下標1) , 返回1
0.13639879(下標0) 大於 -1.1506747(下標1), 返回0
最終得到結果[ 0 1 0]
接下來axis 設置爲 1
#創建一個2*3的數組,使用隨機種子,保證數據不變
a = tf.Variable(tf.random_normal([2,3],seed = 12345))
#初始化變量
init = tf.global_variables_initializer()
#對每一列進行計算,返回最大值下標
b = tf.argmax(a,1)
#啓動會話層
with tf.Session() as sess:
sess.run(init)
print('原始數據:')
print(sess.run(a))
print('返回最大值下標索引')
print(sess.run(b))
輸出結果:
原始數據:
[[ 0.88424665 0.07843047 0.13639879]
[-0.6109575 1.8525681 -1.1506747 ]]
返回最大值下標索引
[0 1]
計算流程:
如圖,把原始數據使用一根紅線豎着分割成兩份(按行分割),從左到右進行對比(按列對比),即:0.088424665和0.07843047和0.13639879進行對比,返回最大值的下標,-0.6109575和1.8525681和-1.1506747進行對比,返回最大值的下標。
0.88424665(下標0) 0.07843047(下標1) 0.13639879(下標2) ,經過比較0.88424665最大,返回0
-0.6109575(下標0) 1.8525681(下標1) -1.1506747(下標2) ,經過比較1.8525681最大,返回1
最終得到結果[0 1]
總結:
tf.argmax函數根據axis的值進行 行索引或者列索引,axis取值[-2,2),半開半閉區間,
當axis = 0或-2時 (按列分割),(按行對比)
當axis = 1或-1時 (按行分割),(按列對比)