1維的可以直接取值,
import torch
anch_ious = torch.Tensor([[1, 2, 3], [4, 5, 6]]).view(-1)
neg_count=4
top_data,index= torch.topk(anch_ious, neg_count, dim=0, largest=True, sorted=True, out=None)
print(anch_ious[index])
2維以上就行了:
能返回index,但是不能根據index獲取到值
需要根據指定維度取值,用gather
import torch
anch_ious = torch.Tensor([[1, 2, 3], [4, 5, 6]])
neg_count=2
top_data,index= torch.topk(anch_ious, neg_count, dim=1, largest=True, sorted=True, out=None)
print(top_data)
# b = torch.LongTensor([0, 1]).view(2, 1)
c = torch.gather(input=anch_ious, dim=1, index=index)
print(c)