tensorflow用gather/scatter实现advanced indexing

tf.gather

numpy 支持用 ndarray 索引:

import numpy as np

arr = np.arange(9).reshape(3, 3)
idx = np.array([0, 2])
print(arr[idx])

但 tensorflow 中非 scalar 的 tensor 可以直接用作下标:

import tensorflow as tf

arr = tf.reshape(tf.range(9), [3, 3])
idx = tf.constant([0, 2])
one = tf.constant(1)

with tf.Session() as sess:
    print(sess.run(arr[one]))  # 可以
    print(sess.run(arr[0:1]))  # 可以
    print(sess.run(arr[idx]))  # 报错

要实现类似的功能,用 tf.gather

import tensorflow as tf

arr = tf.reshape(tf.range(9), [3, 3])
idx = tf.constant([0, 2])

with tf.Session() as sess:
    print(sess.run(tf.gather(arr, idx)))

tf.gather_nd

这次的目标是:给出矩阵 An×mA_{n\times m} 和索引向量 bn×1b_{n\times1},取各 A[i][b[i]],即 A 的每行都取一个元素,下标由 b[i] 决定。用到tf.gather_nd
tf.gather_nd 用元素的「座标」选元素,即传入的第二个参数indices是要选的那些元素的座标的序列。例如对于上述的目标,indices 就是各[i, b[i]]组成的序列。b[i] 已经有,只要补上行座标 i 就行。

Example

  • 这里同时也实现了 tensorflow 的 tensor 随机索引,即生成随机索引向量 b,用于索引张量 A 的分量。
import tensorflow as tf
import numpy as np

n = 3
m = 4

# 备选数组
A = tf.constant(np.arange(n * m).reshape(n, m))
# 随机生成列 id
b = tf.random_uniform([n, 1],
                      minval=0, maxval=m,  # 列 id 范围:[0, m)
                      dtype=tf.int32)
# 补上行 id
row_id = tf.range(n, dtype="int32")[:, None] # 形状:[n, 1]

#print(row_id.shape.as_list())
#print(b.shape.as_list())

# 拼在一起组成完整座标
idx = tf.concat([row_id, b], axis=1)
# 选元素
elem = tf.gather_nd(A, idx)

with tf.Session() as sess:
    A, b, r, i, e = sess.run([A, b, row_id, idx, elem])
    print("A:\n", A)
    print("b:\n", b)
    print("row id:\n", r)
    print("indices:\n", i)
    print("elem:\n", e)

结果:

A:
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]
b:
[[0]
 [1]
 [1]]
row id:
[[0]
 [1]
 [2]]
indices:
[[0 0]
 [1 1]
 [2 1]]
elem:
[0 5 9]

(0,5,9)( A[0][0], A[1][1], A[2][1] )

tf.batch_gather

tf.batch_gathertf.argsort/tf.top_k 实现矩阵分行排序。场景是:用 A 的行数据 argsort 得出的 indices 来对 B 行数据排序。

import tensorflow as tf

a = tf.constant([
    [1, 0, 3, 2, 5],
    [4, 7, 9, 8, 6]
])
b = tf.reshape(tf.range(10, 20), [2, 5])

# k_idx = tf.argsort(ham)  # tf 1.12 无 `argsort`…用 top_k 代替
k_val, k_idx = tf.math.top_k(- a, a.shape[1])  # minus for ascending
b_sort = tf.batch_gather(b, k_idx)

with tf.Session() as sess:
    a, b, k_idx, b_sort = sess.run([a, b, k_idx, b_sort])
    print("a:\n", a)
    print("b:\n", b)
    print("k_idx:\n", k_idx)
    print("b_sort:\n", b_sort)

结果:

a:
[[1 0 3 2 5]
 [4 7 9 8 6]]
b:
[[10 11 12 13 14]
 [15 16 17 18 19]]
k_idx:
[[1 0 3 2 4]
 [0 4 1 3 2]]
b_sort:
[[11 10 13 12 14]
 [15 19 16 18 17]]

tf.scatter_nd_update / tf.tensor_scatter_nd_update

将高级索引用于左值。注意 tensorflow 需要对 index 升维,详见代码。

in numpy

这个功能对应的 numpy 示例

import numpy as np

a = np.arange(12).reshape(3, 4)
print("before:\n", a)
idx = np.array([0, 2])  # numpy 的 index 不 需要升维
val = np.array([
    [11, 12, 13, 14],
    [15, 16, 17, 18]
])
a[idx] = val  # 左值用高级索引
print("after:\n", a)

输出:

before:
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]
after:
[[11 12 13 14]
 [ 4  5  6  7]
 [15 16 17 18]]

in tensorflow 1.12

tf.scatter_nd_update,见 [7]。

import tensorflow as tf

sess = tf.Session()

a = tf.Variable(tf.reshape(tf.range(12), [3, 4]))
sess.run(tf.global_variables_initializer())
print("before:\n", sess.run(a))
idx = tf.constant([[0], [2]])  # index 升维
val = tf.constant([
    [11, 12, 13, 14],
    [15, 16, 17, 18]
])
update = a.assign(tf.scatter_nd_update(a, idx, val))  # 左值用高级索引
print("after:\n", sess.run([a, update])[0])

sess.close()

输出:

before:
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]
after:
[[11 12 13 14]
 [ 4  5  6  7]
 [15 16 17 18]]

in tensorflow 2.1

tf.tensor_scatter_nd_update,见 [8]。

import tensorflow as tf

a = tf.Variable(tf.reshape(tf.range(12), [3, 4]))
print("before:\n", a)
idx = tf.constant([[0], [2]])  # index 升维
val = tf.constant([
    [11, 12, 13, 14],
    [15, 16, 17, 18]
])
a.assign(tf.tensor_scatter_nd_update(a, idx, val))  # 左值用高级索引
print("after:\n", a)

输出:

before:
<tf.Variable 'Variable:0' shape=(3, 4) dtype=int32, numpy=
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)>
after:
<tf.Variable 'Variable:0' shape=(3, 4) dtype=int32, numpy=
array([[11, 12, 13, 14],
       [ 4,  5,  6,  7],
       [15, 16, 17, 18]], dtype=int32)>

References

  1. TensorFlow - numpy-like tensor indexing
  2. Generalize slicing and slice assignment ops (including gather and scatter) #206
  3. TF 中的 indexing 和 slicing
  4. tf.gather和tf.gather_nd的详细用法–tensorflow通过索引取tensor里的数据
  5. 从Tensorflow中从另一个中挑选随机张量
  6. tf.gather tf.gather_nd 和 tf.batch_gather 使用方法
  7. Tensorflow深度学习之三十二: tf.scatter_nd_update
  8. tf.tensor_scatter_nd_update
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章