tensorflow用gather/scatter實現advanced indexing

tf.gather

numpy 支持用 ndarray 索引:

import numpy as np

arr = np.arange(9).reshape(3, 3)
idx = np.array([0, 2])
print(arr[idx])

但 tensorflow 中非 scalar 的 tensor 可以直接用作下標:

import tensorflow as tf

arr = tf.reshape(tf.range(9), [3, 3])
idx = tf.constant([0, 2])
one = tf.constant(1)

with tf.Session() as sess:
    print(sess.run(arr[one]))  # 可以
    print(sess.run(arr[0:1]))  # 可以
    print(sess.run(arr[idx]))  # 報錯

要實現類似的功能,用 tf.gather

import tensorflow as tf

arr = tf.reshape(tf.range(9), [3, 3])
idx = tf.constant([0, 2])

with tf.Session() as sess:
    print(sess.run(tf.gather(arr, idx)))

tf.gather_nd

這次的目標是:給出矩陣 An×mA_{n\times m} 和索引向量 bn×1b_{n\times1},取各 A[i][b[i]],即 A 的每行都取一個元素,下標由 b[i] 決定。用到tf.gather_nd
tf.gather_nd 用元素的「座標」選元素,即傳入的第二個參數indices是要選的那些元素的座標的序列。例如對於上述的目標,indices 就是各[i, b[i]]組成的序列。b[i] 已經有,只要補上行座標 i 就行。

Example

  • 這裏同時也實現了 tensorflow 的 tensor 隨機索引,即生成隨機索引向量 b,用於索引張量 A 的分量。
import tensorflow as tf
import numpy as np

n = 3
m = 4

# 備選數組
A = tf.constant(np.arange(n * m).reshape(n, m))
# 隨機生成列 id
b = tf.random_uniform([n, 1],
                      minval=0, maxval=m,  # 列 id 範圍:[0, m)
                      dtype=tf.int32)
# 補上行 id
row_id = tf.range(n, dtype="int32")[:, None] # 形狀:[n, 1]

#print(row_id.shape.as_list())
#print(b.shape.as_list())

# 拼在一起組成完整座標
idx = tf.concat([row_id, b], axis=1)
# 選元素
elem = tf.gather_nd(A, idx)

with tf.Session() as sess:
    A, b, r, i, e = sess.run([A, b, row_id, idx, elem])
    print("A:\n", A)
    print("b:\n", b)
    print("row id:\n", r)
    print("indices:\n", i)
    print("elem:\n", e)

結果:

A:
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]
b:
[[0]
 [1]
 [1]]
row id:
[[0]
 [1]
 [2]]
indices:
[[0 0]
 [1 1]
 [2 1]]
elem:
[0 5 9]

(0,5,9)( A[0][0], A[1][1], A[2][1] )

tf.batch_gather

tf.batch_gathertf.argsort/tf.top_k 實現矩陣分行排序。場景是:用 A 的行數據 argsort 得出的 indices 來對 B 行數據排序。

import tensorflow as tf

a = tf.constant([
    [1, 0, 3, 2, 5],
    [4, 7, 9, 8, 6]
])
b = tf.reshape(tf.range(10, 20), [2, 5])

# k_idx = tf.argsort(ham)  # tf 1.12 無 `argsort`…用 top_k 代替
k_val, k_idx = tf.math.top_k(- a, a.shape[1])  # minus for ascending
b_sort = tf.batch_gather(b, k_idx)

with tf.Session() as sess:
    a, b, k_idx, b_sort = sess.run([a, b, k_idx, b_sort])
    print("a:\n", a)
    print("b:\n", b)
    print("k_idx:\n", k_idx)
    print("b_sort:\n", b_sort)

結果:

a:
[[1 0 3 2 5]
 [4 7 9 8 6]]
b:
[[10 11 12 13 14]
 [15 16 17 18 19]]
k_idx:
[[1 0 3 2 4]
 [0 4 1 3 2]]
b_sort:
[[11 10 13 12 14]
 [15 19 16 18 17]]

tf.scatter_nd_update / tf.tensor_scatter_nd_update

將高級索引用於左值。注意 tensorflow 需要對 index 升維,詳見代碼。

in numpy

這個功能對應的 numpy 示例

import numpy as np

a = np.arange(12).reshape(3, 4)
print("before:\n", a)
idx = np.array([0, 2])  # numpy 的 index 不 需要升維
val = np.array([
    [11, 12, 13, 14],
    [15, 16, 17, 18]
])
a[idx] = val  # 左值用高級索引
print("after:\n", a)

輸出:

before:
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]
after:
[[11 12 13 14]
 [ 4  5  6  7]
 [15 16 17 18]]

in tensorflow 1.12

tf.scatter_nd_update,見 [7]。

import tensorflow as tf

sess = tf.Session()

a = tf.Variable(tf.reshape(tf.range(12), [3, 4]))
sess.run(tf.global_variables_initializer())
print("before:\n", sess.run(a))
idx = tf.constant([[0], [2]])  # index 升維
val = tf.constant([
    [11, 12, 13, 14],
    [15, 16, 17, 18]
])
update = a.assign(tf.scatter_nd_update(a, idx, val))  # 左值用高級索引
print("after:\n", sess.run([a, update])[0])

sess.close()

輸出:

before:
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]
after:
[[11 12 13 14]
 [ 4  5  6  7]
 [15 16 17 18]]

in tensorflow 2.1

tf.tensor_scatter_nd_update,見 [8]。

import tensorflow as tf

a = tf.Variable(tf.reshape(tf.range(12), [3, 4]))
print("before:\n", a)
idx = tf.constant([[0], [2]])  # index 升維
val = tf.constant([
    [11, 12, 13, 14],
    [15, 16, 17, 18]
])
a.assign(tf.tensor_scatter_nd_update(a, idx, val))  # 左值用高級索引
print("after:\n", a)

輸出:

before:
<tf.Variable 'Variable:0' shape=(3, 4) dtype=int32, numpy=
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)>
after:
<tf.Variable 'Variable:0' shape=(3, 4) dtype=int32, numpy=
array([[11, 12, 13, 14],
       [ 4,  5,  6,  7],
       [15, 16, 17, 18]], dtype=int32)>

References

  1. TensorFlow - numpy-like tensor indexing
  2. Generalize slicing and slice assignment ops (including gather and scatter) #206
  3. TF 中的 indexing 和 slicing
  4. tf.gather和tf.gather_nd的詳細用法–tensorflow通過索引取tensor裏的數據
  5. 從Tensorflow中從另一箇中挑選隨機張量
  6. tf.gather tf.gather_nd 和 tf.batch_gather 使用方法
  7. Tensorflow深度學習之三十二: tf.scatter_nd_update
  8. tf.tensor_scatter_nd_update
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章