tf.argmax()函數

argmax(input, axis=None, name=None, dimension=None, output_type=tf.int64)
    Returns the index with the largest value across axes of a tensor. (deprecated arguments)

根據axis的值返回行或者列最大值的下標,axis取值[-2,2)

上代碼

#創建一個2*3的數組,使用隨機種子,保證數據不變
a = tf.Variable(tf.random_normal([2,3],seed = 12345))
#初始化變量
init = tf.global_variables_initializer()
#對每一列進行計算,返回最大值下標
b = tf.argmax(a,0)
#啓動會話層
with tf.Session() as sess:
    sess.run(init)
    print(sess.run(a))
    print(sess.run(b))

輸出結果:

原始數據:
[[ 0.88424665  0.07843047  0.13639879]
 [-0.6109575   1.8525681  -1.1506747 ]]
返回最大值下標索引
[0 1 0]

它是如何返回 [ 0 1 0]的呢

計算流程:

如圖,把原始數據使用兩根紅線豎着分割成三份(按列分割),從上到下進行對比(按行對比),即:0.088424665-0.6109575進行對比,返回最大值的下標,0.078430471.8525681進行對比,返回最大值的下標,0.13639879-1.1506747進行對比,返回最大值的下標。

0.088424665(下標0)大於 -0.6109575(下標1) ,返回0 

0.07843047(下標0) 小於 1.8525681(下標1) , 返回1 

0.13639879(下標0) 大於 -1.1506747(下標1), 返回0 

最終得到結果[ 0 1 0]

 

接下來axis 設置爲 1 

#創建一個2*3的數組,使用隨機種子,保證數據不變
a = tf.Variable(tf.random_normal([2,3],seed = 12345))
#初始化變量
init = tf.global_variables_initializer()
#對每一列進行計算,返回最大值下標
b = tf.argmax(a,1)
#啓動會話層
with tf.Session() as sess:
    sess.run(init)
    print('原始數據:')
    print(sess.run(a))
    print('返回最大值下標索引')
    print(sess.run(b))

輸出結果:

原始數據:
[[ 0.88424665  0.07843047  0.13639879]
 [-0.6109575   1.8525681  -1.1506747 ]]
返回最大值下標索引
[0 1]

計算流程:

如圖,把原始數據使用一根紅線豎着分割成兩份(按行分割),從左到右進行對比(按列對比),即:0.0884246650.078430470.13639879進行對比,返回最大值的下標,-0.61095751.8525681-1.1506747進行對比,返回最大值的下標。

 0.88424665(下標0)  0.07843047(下標1)   0.13639879(下標2) ,經過比較0.88424665最大,返回0

-0.6109575(下標0)   1.8525681(下標1)   -1.1506747(下標2) ,經過比較1.8525681最大,返回1

最終得到結果[0 1]

 

總結:

tf.argmax函數根據axis的值進行 行索引或者列索引,axis取值[-2,2),半開半閉區間,

當axis = 0或-2時 (按列分割),(按行對比)

當axis = 1或-1時 (按行分割),(按列對比)

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章