tensorflow tf.nn.top_k 生成mask 提取值

# 通過生成boolean tensor的辦法:
a = tf.convert_to_tensor([[40, 30, 20, 10], [10, 20, 30, 40]])
b = tf.nn.top_k(a, 2)

print(sess.run(b))
TopKV2(values=array([[40, 30],
   [40, 30]], dtype=int32), indices=array([[0, 1],
   [3, 2]], dtype=int32))

print(sess.run(b).values))
array([[40, 30],
       [40, 30]], dtype=int32)

kth = tf.reduce_min(b.values,1,keepdims=True) # 找出最小值
top2 = tf.greater_equal(a, kth) # 大於最小值的爲true
print(sess.run(top2))
array([[ True,  True, False, False],
       [False, False,  True,  True]], dtype=bool)
# 通過生成id後scatter的辦法:
import tensorflow as tf

# Input data
a = tf.placeholder(tf.float32, [None, None])
num_top = tf.placeholder(tf.int32, [])
# Find top elements
a_top, a_top_idx = tf.nn.top_k(a, num_top, sorted=False)
# Apply softmax
a_top_sm = tf.nn.softmax(a_top)
# Reconstruct into original shape
a_shape = tf.shape(a)
a_row_idx = tf.tile(tf.range(a_shape[0])[:, tf.newaxis], (1, num_top))
scatter_idx = tf.stack([a_row_idx, a_top_idx], axis=-1)  # 生成scatter_index
result = tf.scatter_nd(scatter_idx, a_top_sm, a_shape)  #生成矩陣
# Test
with tf.Session() as sess:
    result_val = sess.run(result, feed_dict={a: [[2, 5, 4, 7], [7, 5, 6, 8]], num_top: 2})
    print(result_val)

[[0.         0.11920291 0.         0.880797  ]
 [0.26894143 0.         0.         0.7310586 ]]
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章