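TensorFlow 官網上關於 tf.scatter_nd 的介紹比較簡單,這裏提供更多例子,幫助理解。要點:indices 最後一維的長度 K 表示每個索引指向輸出張量的前 K 維,因此 updates 的形狀必須等於 indices.shape[:-1] + shape[K:]。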
例1:
import tensorflow as tf
# Example 1: scatter two length-4 rows into a zero tensor of shape [4, 4, 4].
# Each row of `indices` has 2 components (index depth K = 2), so it addresses
# output[i, j, :]. tf.scatter_nd therefore requires
#     updates.shape == indices.shape[:-1] + shape[K:] == [2] + [4] == [2, 4].
# (A [2, 4, 4] `updates` tensor only fits depth-1 indices such as [[0], [2]].)
indices = tf.constant([[0, 0], [1, 2]])
updates = tf.constant([[5, 5, 5, 5], [6, 6, 6, 6]])
shape = tf.constant([4, 4, 4])
scatter = tf.scatter_nd(indices, updates, shape)
# tf.Session only exists in TF 1.x graph mode; under TF 2.x eager execution
# the tensor is already evaluated and can be read directly with .numpy().
if tf.executing_eagerly():
    print(scatter.numpy())
else:
    with tf.Session() as sess:
        print(sess.run(scatter))
結果:
[[[5 5 5 5]
[0 0 0 0]
[0 0 0 0]
[0 0 0 0]]
[[0 0 0 0]
[0 0 0 0]
[6 6 6 6]
[0 0 0 0]]
[[0 0 0 0]
[0 0 0 0]
[0 0 0 0]
[0 0 0 0]]
[[0 0 0 0]
[0 0 0 0]
[0 0 0 0]
[0 0 0 0]]]
例2:
import tensorflow as tf
# Example 2: the same two rows scattered into a 4-D zero tensor [1, 4, 4, 4].
# Each index row has 3 components (index depth K = 3), addressing
# output[a, b, c, :], so updates.shape == [2] + shape[3:] == [2, 4].
indices = tf.constant([[0, 0, 0], [0, 1, 2]])
updates = tf.constant([[5, 5, 5, 5], [6, 6, 6, 6]])
shape = tf.constant([1, 4, 4, 4])
scatter = tf.scatter_nd(indices, updates, shape)
# tf.Session only exists in TF 1.x graph mode; under TF 2.x eager execution
# the tensor is already evaluated and can be read directly with .numpy().
if tf.executing_eagerly():
    print(scatter.numpy())
else:
    with tf.Session() as sess:
        print(sess.run(scatter))
結果:
[[[[5 5 5 5]
[0 0 0 0]
[0 0 0 0]
[0 0 0 0]]
[[0 0 0 0]
[0 0 0 0]
[6 6 6 6]
[0 0 0 0]]
[[0 0 0 0]
[0 0 0 0]
[0 0 0 0]
[0 0 0 0]]
[[0 0 0 0]
[0 0 0 0]
[0 0 0 0]
[0 0 0 0]]]]
例3:
import tensorflow as tf
# Example 3 (deliberately invalid): both index rows start with 1, but
# shape[0] == 1, so the only valid value for the first component is 0.
# According to the tf.scatter_nd documentation, an out-of-bound index returns
# an error on CPU and is ignored on GPU. In TF 2.x eager mode on CPU the error
# is raised by the tf.scatter_nd call itself, before anything is printed.
indices = tf.constant([[1, 0, 0], [1, 1, 2]])
updates = tf.constant([[5, 5, 5, 5], [6, 6, 6, 6]])
shape = tf.constant([1, 4, 4, 4])
scatter = tf.scatter_nd(indices, updates, shape)
# tf.Session only exists in TF 1.x graph mode; under TF 2.x eager execution
# the tensor is already evaluated and can be read directly with .numpy().
if tf.executing_eagerly():
    print(scatter.numpy())
else:
    with tf.Session() as sess:
        print(sess.run(scatter))
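說明:例3 中兩個索引的第一個分量都是 1,但 shape[0] = 1,該維合法索引只有 0,因此屬於越界索引。官方文檔說明:在 CPU 上遇到越界索引會報錯,在 GPU 上則會忽略該索引。下面先輸出了越界報錯日誌,最終結果卻是全 0 張量,推測越界的更新被忽略了(例如實際在 GPU 上運行)。
結果: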
OP_REQUIRES failed at scatter_nd_op.cc:119 : Invalid argument: Invalid indices: [0,0] = [1, 0, 0] does not index into [1,4,4,4]
[[[[0 0 0 0]
[0 0 0 0]
[0 0 0 0]
[0 0 0 0]]
[[0 0 0 0]
[0 0 0 0]
[0 0 0 0]
[0 0 0 0]]
[[0 0 0 0]
[0 0 0 0]
[0 0 0 0]
[0 0 0 0]]
[[0 0 0 0]
[0 0 0 0]
[0 0 0 0]
[0 0 0 0]]]]