PyTorch - torch.gather

PyTorch - torch.gather

flyfish

這是一篇讓您能懂的torch.gather的文章,下面的例子比官網更容易說明該函數的用法
函數的作用,沿dim指定的軸收集值(Gathers values along an axis specified by dim),相當於 我們有一個二維表根據指定的索引值,把數據取出來。例如一個索引中存儲了二維表中每行的最大值或者最小值,我們就可以使用該函數把二維表中每行的最大值都取出來。
重要的3個參數
input 表示輸入的tensor,從哪裏找數據
dim 按照哪一維找數據
indiex 按照該維找數據的索引有哪些

input (Tensor) – the source tensor
dim (int) – the axis along which to index
index (LongTensor) – the indices of elements to gather

二維舉例,我們想構建這樣的一個二維表,假設 二維表是 a[i][j] 形式,
a[i][j] i表示0維,j表示1維

a[0][0]=0 a[0][1]=1 a[0][2]=2 a[0][3]=3
a[1][0]=4 a[1][1]=5 a[1][2]=6 a[1][3]=7
a[2][0]=8 a[2][1]=9 a[2][2]=10 a[2][3]=11
a[3][0]=12 a[3][1]=13 a[3][2]=14 a[3][3]=15
import torch
a = torch.arange(0, 16).view(4, 4)
print(a)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11],
#         [12, 13, 14, 15]])

#---------------------------------------------------------------------------------

index_1 = torch.LongTensor([[0, 1, 2, 3]])
print(a.gather(0, index_1))# 作用在0維上索引是index_1,相當於i=index_1, j是不變的0,1,2,3
#輸出結果
#tensor([[ 0,  5, 10, 15]])
# 從0維找數據,相當於
# a[i][j]
# a[0][0] = 0
# a[1][1] = 5
# a[2][2] = 10
# a[3][3] = 15

#---------------------------------------------------------------------------------

index_2 = torch.LongTensor([[3, 2, 1, 0]])
print(a.gather(0, index_2))#作用在0維上索引是index_2,相當於i=index_2, j是不變的0,1,2,3
#輸出結果
#tensor([[12,  9,  6,  3]])
# 從0維找數據,相當於
a[i][j]
a[3][0] = 12
a[2][1] = 9
a[1][2] = 6
a[0][3] = 3

#---------------------------------------------------------------------------------

index_3 = torch.LongTensor([[0, 1, 2, 3]]).t()
print(a.gather(1, index_3))#作用在1維上索引是index_3,相當於i是不變的0,1,2,3,j=index_3
#輸出結果
# tensor([[ 0],
#         [ 5],
#         [10],
#         [15]])

a[i][j]
a[0][0] = 0
a[1][1] = 5
a[2][2] = 10
a[3][3] = 15

#---------------------------------------------------------------------------------
如果不想 加.t() 可以換成下面的形式

index_4 = torch.LongTensor([[3], [2], [1], [0]])
print(a.gather(1, index_4))##作用在1維上索引是index_4,相當於i是不變的0,1,2,3,j=index_4
#輸出結果
# tensor([[ 3],
#         [ 6],
#         [ 9],
#         [12]])
# a[i][j]
# a[0][3] = 3
# a[1][2] = 6
# a[2][1] = 9
# a[3][0] = 12

#---------------------------------------------------------------------------------
再隨意舉個例子

index_5 = torch.LongTensor([[3, 3, 3, 3]])
print(a.gather(0, index_5))#作用在0維上索引是index_2,相當於i=index_5, j是不變的0,1,2,3
#輸出結果
#tensor([[12, 13, 14, 15]])
# 從0維找數據,相當於
a[i][j]
a[3][0] = 12
a[3][1] = 13
a[3][2] = 14
a[3][3] = 15
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章