torch.topk

 

1維的可以直接取值,

    import torch

    anch_ious = torch.Tensor([[1, 2, 3], [4, 5, 6]]).view(-1)


    neg_count=4
    top_data,index= torch.topk(anch_ious, neg_count, dim=0, largest=True, sorted=True, out=None)

    print(anch_ious[index])

2維以上就行了:

能返回index,但是不能根據index獲取到值

需要根據指定維度取值,用gather

    import torch

    anch_ious = torch.Tensor([[1, 2, 3], [4, 5, 6]])

    neg_count=2
    top_data,index= torch.topk(anch_ious, neg_count, dim=1, largest=True, sorted=True, out=None)

    print(top_data)

    # b = torch.LongTensor([0, 1]).view(2, 1)

    c = torch.gather(input=anch_ious, dim=1, index=index)
    print(c)

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章