import tensorflow as tf
tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]])
tensor_b = tf.Variable([[0],[1],[2]],dtype=tf.int32)
tensor_c = tf.Variable([[0],[0],[0]],dtype=tf.int32)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(tf.batch_gather(tensor_a,tensor_b)))
print(sess.run(tf.batch_gather(tensor_a,tensor_c)))
如上代碼,輸出結果爲
[[1]
[5]
[9]]
[[1]
[4]
[7]]
簡單的說,batch_gather就是通過索引來獲取數組的值。gather怎麼理解?目前我還想不出來。