pytorch tensor 篩選排除

 

篩選排除還沒找到答案:

 

取數運算
 

正好遇到一個需求。

我有m行k列的一個表a,和一個長爲m的索引列表b。

b中儲存着,取每行第幾列的元素。

這種情況下,你用普通的索引是會失效的。

import torch
a= torch.LongTensor([[1,2,3],[4,5,6]])
b= torch.LongTensor([0,1])
 

錯誤寫法:
c= a[b]
print(c)
 結果是第1行和第2行
 

方法1:

    import torch

    conf_data = torch.LongTensor([[1, 2, 3], [4, 5, 6]])
    b = torch.LongTensor([0, 1])

    index_num = torch.arange(0, conf_data.size(0))

    print(conf_data[index_num,b])


 

經過一番查找,發現我們可以用神奇的torch.gather()函數

import torch
 
a= torch.LongTensor([[1,2,3],[4,5,6]])
b= torch.LongTensor([0,1]).view(2,1)
 
c= torch.gather(input=a,dim=1,index=b)
print(c)
#tensor([[1],
        [5]])
 
 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章