tf.gather
numpy 支持用 ndarray 索引:
import numpy as np
arr = np.arange(9).reshape(3, 3)
idx = np.array([0, 2])
print(arr[idx])
但 tensorflow 中非 scalar 的 tensor 不可以直接用作下標:
import tensorflow as tf
arr = tf.reshape(tf.range(9), [3, 3])
idx = tf.constant([0, 2])
one = tf.constant(1)
with tf.Session() as sess:
print(sess.run(arr[one])) # 可以
print(sess.run(arr[0:1])) # 可以
print(sess.run(arr[idx])) # 報錯
要實現類似的功能,用 tf.gather
:
import tensorflow as tf
arr = tf.reshape(tf.range(9), [3, 3])
idx = tf.constant([0, 2])
with tf.Session() as sess:
print(sess.run(tf.gather(arr, idx)))
tf.gather_nd
這次的目標是:給出矩陣 和索引向量 ,取各 A[i][b[i]],即 A 的每行都取一個元素,下標由 b[i] 決定。用到tf.gather_nd
。
tf.gather_nd 用元素的「座標」選元素,即傳入的第二個參數indices
是要選的那些元素的座標的序列。例如對於上述的目標,indices 就是各[i, b[i]]
組成的序列。b[i] 已經有,只要補上行座標 i 就行。
Example
- 這裏同時也實現了 tensorflow 的 tensor 隨機索引,即生成隨機索引向量 b,用於索引張量 A 的分量。
import tensorflow as tf
import numpy as np
n = 3
m = 4
# 備選數組
A = tf.constant(np.arange(n * m).reshape(n, m))
# 隨機生成列 id
b = tf.random_uniform([n, 1],
minval=0, maxval=m, # 列 id 範圍:[0, m)
dtype=tf.int32)
# 補上行 id
row_id = tf.range(n, dtype="int32")[:, None] # 形狀:[n, 1]
#print(row_id.shape.as_list())
#print(b.shape.as_list())
# 拼在一起組成完整座標
idx = tf.concat([row_id, b], axis=1)
# 選元素
elem = tf.gather_nd(A, idx)
with tf.Session() as sess:
A, b, r, i, e = sess.run([A, b, row_id, idx, elem])
print("A:\n", A)
print("b:\n", b)
print("row id:\n", r)
print("indices:\n", i)
print("elem:\n", e)
結果:
A:
[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
b:
[[0]
[1]
[1]]
row id:
[[0]
[1]
[2]]
indices:
[[0 0]
[1 1]
[2 1]]
elem:
[0 5 9]
(0,5,9)
即 ( A[0][0], A[1][1], A[2][1] )
tf.batch_gather
用 tf.batch_gather
和 tf.argsort
/tf.top_k
實現矩陣分行排序。場景是:用 A 的行數據 argsort 得出的 indices 來對 B 行數據排序。
import tensorflow as tf
a = tf.constant([
[1, 0, 3, 2, 5],
[4, 7, 9, 8, 6]
])
b = tf.reshape(tf.range(10, 20), [2, 5])
# k_idx = tf.argsort(ham) # tf 1.12 無 `argsort`…用 top_k 代替
k_val, k_idx = tf.math.top_k(- a, a.shape[1]) # minus for ascending
b_sort = tf.batch_gather(b, k_idx)
with tf.Session() as sess:
a, b, k_idx, b_sort = sess.run([a, b, k_idx, b_sort])
print("a:\n", a)
print("b:\n", b)
print("k_idx:\n", k_idx)
print("b_sort:\n", b_sort)
結果:
a:
[[1 0 3 2 5]
[4 7 9 8 6]]
b:
[[10 11 12 13 14]
[15 16 17 18 19]]
k_idx:
[[1 0 3 2 4]
[0 4 1 3 2]]
b_sort:
[[11 10 13 12 14]
[15 19 16 18 17]]
tf.scatter_nd_update / tf.tensor_scatter_nd_update
將高級索引用於左值。注意 tensorflow 需要對 index 升維,詳見代碼。
in numpy
這個功能對應的 numpy 示例
import numpy as np
a = np.arange(12).reshape(3, 4)
print("before:\n", a)
idx = np.array([0, 2]) # numpy 的 index 不 需要升維
val = np.array([
[11, 12, 13, 14],
[15, 16, 17, 18]
])
a[idx] = val # 左值用高級索引
print("after:\n", a)
輸出:
before:
[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
after:
[[11 12 13 14]
[ 4 5 6 7]
[15 16 17 18]]
in tensorflow 1.12
用 tf.scatter_nd_update
,見 [7]。
import tensorflow as tf
sess = tf.Session()
a = tf.Variable(tf.reshape(tf.range(12), [3, 4]))
sess.run(tf.global_variables_initializer())
print("before:\n", sess.run(a))
idx = tf.constant([[0], [2]]) # index 升維
val = tf.constant([
[11, 12, 13, 14],
[15, 16, 17, 18]
])
update = a.assign(tf.scatter_nd_update(a, idx, val)) # 左值用高級索引
print("after:\n", sess.run([a, update])[0])
sess.close()
輸出:
before:
[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
after:
[[11 12 13 14]
[ 4 5 6 7]
[15 16 17 18]]
in tensorflow 2.1
用 tf.tensor_scatter_nd_update
,見 [8]。
import tensorflow as tf
a = tf.Variable(tf.reshape(tf.range(12), [3, 4]))
print("before:\n", a)
idx = tf.constant([[0], [2]]) # index 升維
val = tf.constant([
[11, 12, 13, 14],
[15, 16, 17, 18]
])
a.assign(tf.tensor_scatter_nd_update(a, idx, val)) # 左值用高級索引
print("after:\n", a)
輸出:
before:
<tf.Variable 'Variable:0' shape=(3, 4) dtype=int32, numpy=
array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]], dtype=int32)>
after:
<tf.Variable 'Variable:0' shape=(3, 4) dtype=int32, numpy=
array([[11, 12, 13, 14],
[ 4, 5, 6, 7],
[15, 16, 17, 18]], dtype=int32)>
References
- TensorFlow - numpy-like tensor indexing
- Generalize slicing and slice assignment ops (including gather and scatter) #206
- TF 中的 indexing 和 slicing
- tf.gather和tf.gather_nd的詳細用法–tensorflow通過索引取tensor裏的數據
- 從Tensorflow中從另一箇中挑選隨機張量
- tf.gather tf.gather_nd 和 tf.batch_gather 使用方法
- Tensorflow深度學習之三十二: tf.scatter_nd_update
- tf.tensor_scatter_nd_update