【tensorflow】batch_gather使用方法

import tensorflow as tf
tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]])
tensor_b = tf.Variable([[0],[1],[2]],dtype=tf.int32)
tensor_c = tf.Variable([[0],[0],[0]],dtype=tf.int32)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.batch_gather(tensor_a,tensor_b)))
    print(sess.run(tf.batch_gather(tensor_a,tensor_c)))

如上代碼,輸出結果爲

[[1]
 [5]
 [9]]
[[1]
 [4]
 [7]]
簡單的說,batch_gather就是通過索引來獲取數組的值。gather怎麼理解?目前我還想不出來。

 

發佈了294 篇原創文章 · 獲贊 107 · 訪問量 49萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章