篩選排除還沒找到答案:
取數運算
正好遇到一個需求。
我有m行k列的一個表a,和一個長爲m的索引列表b。
b中儲存着,取每行第幾列的元素。
這種情況下,你用普通的索引是會失效的。
import torch
a= torch.LongTensor([[1,2,3],[4,5,6]])
b= torch.LongTensor([0,1])
錯誤寫法:
c= a[b]
print(c)
結果是第1行和第2行
方法1:
import torch
conf_data = torch.LongTensor([[1, 2, 3], [4, 5, 6]])
b = torch.LongTensor([0, 1])
index_num = torch.arange(0, conf_data.size(0))
print(conf_data[index_num,b])
經過一番查找,發現我們可以用神奇的torch.gather()函數
import torch
a= torch.LongTensor([[1,2,3],[4,5,6]])
b= torch.LongTensor([0,1]).view(2,1)
c= torch.gather(input=a,dim=1,index=b)
print(c)
#tensor([[1],
[5]])