Keras對多維Tensor的argmax()解析

基礎理論

argmax中的axis參數表示在該維度上比較各元素。並且,張量各維度對換,不影響在該維度取argmax()的結果。

a = tf.constant([[[1, 2, 3], [3, 2, 2]], [[10, 11, 12], [4, 5, 6]]])  # a是個2*2*3的tensor
b = tf.argmax(a, axis=1, output_type=tf.int32)
at = tf.transpose(a, [0, 2, 1])  # 將DIM1和DIM2對換,at變成了2*3*2
c = tf.argmax(at, axis=2, output_type=tf.int32)

with tf.Session() as sess:
    print(sess.run(b))
    print(sess.run(c))
print("")

輸出結果

[[1 0 0]
 [0 0 0]]
[[1 0 0]
 [0 0 0]]

tf.argmax(a, axis=1)相當於是在a的DIM1上比較,也就是1和3,2和2,3和2,以及10和4,11和5,12和6比較。如果改成tf.argmax(a, axis=0),相當於是a在DIM0上比較,也就是1和10,2和11,3和12,以此類推。

應用場景

比如,目前有分子特徵張量input,維度爲SampNum × AtomNum × FeatNum,那麼,argmax(input, axis=1)將得到維度爲SampNum × FeatNum的Tensor,其元素表示各樣本分子的各種向量值表徵、同種向量的最大者所對應的原子id。
同樣的,再來一個,argmax(input, axis=2)將得到維度爲SampNum × AtomNum的Tensor,其元素表示各樣本分子的各原子的FeatNum種特徵中,最大的特徵值所對應的特徵id。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章