pytorch tensor 筛选排除

 

筛选排除还没找到答案:

 

取数运算
 

正好遇到一个需求。

我有m行k列的一个表a,和一个长为m的索引列表b。

b中储存着,取每行第几列的元素。

这种情况下,你用普通的索引是会失效的。

import torch
a= torch.LongTensor([[1,2,3],[4,5,6]])
b= torch.LongTensor([0,1])
 

错误写法:
c= a[b]
print(c)
 结果是第1行和第2行
 

方法1:

    import torch

    conf_data = torch.LongTensor([[1, 2, 3], [4, 5, 6]])
    b = torch.LongTensor([0, 1])

    index_num = torch.arange(0, conf_data.size(0))

    print(conf_data[index_num,b])


 

经过一番查找,发现我们可以用神奇的torch.gather()函数

import torch
 
a= torch.LongTensor([[1,2,3],[4,5,6]])
b= torch.LongTensor([0,1]).view(2,1)
 
c= torch.gather(input=a,dim=1,index=b)
print(c)
#tensor([[1],
        [5]])
 
 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章