基礎理論
argmax中的axis參數表示在該維度上比較各元素。並且,張量各維度對換,不影響在該維度取argmax()的結果。
a = tf.constant([[[1, 2, 3], [3, 2, 2]], [[10, 11, 12], [4, 5, 6]]]) # a是個2*2*3的tensor
b = tf.argmax(a, axis=1, output_type=tf.int32)
at = tf.transpose(a, [0, 2, 1]) # 將DIM1和DIM2對換,at變成了2*3*2
c = tf.argmax(at, axis=2, output_type=tf.int32)
with tf.Session() as sess:
print(sess.run(b))
print(sess.run(c))
print("")
輸出結果
[[1 0 0]
[0 0 0]]
[[1 0 0]
[0 0 0]]
tf.argmax(a, axis=1)相當於是在a的DIM1上比較,也就是1和3,2和2,3和2,以及10和4,11和5,12和6比較。如果改成tf.argmax(a, axis=0),相當於是a在DIM0上比較,也就是1和10,2和11,3和12,以此類推。
應用場景
比如,目前有分子特徵張量input,維度爲SampNum × AtomNum × FeatNum,那麼,argmax(input, axis=1)將得到維度爲SampNum × FeatNum的Tensor,其元素表示各樣本分子的各種向量值表徵、同種向量的最大者所對應的原子id。
同樣的,再來一個,argmax(input, axis=2)將得到維度爲SampNum × AtomNum的Tensor,其元素表示各樣本分子的各原子的FeatNum種特徵中,最大的特徵值所對應的特徵id。