PyTorch - torch.gather
flyfish
這是一篇讓您能懂的torch.gather的文章,下面的例子比官網更容易說明該函數的用法
函數的作用,沿dim指定的軸收集值(Gathers values along an axis specified by dim),相當於 我們有一個二維表根據指定的索引值,把數據取出來。例如一個索引中存儲了二維表中每行的最大值或者最小值,我們就可以使用該函數把二維表中每行的最大值都取出來。
重要的3個參數
input 表示輸入的tensor,從哪裏找數據
dim 按照哪一維找數據
indiex 按照該維找數據的索引有哪些
input (Tensor) – the source tensor
dim (int) – the axis along which to index
index (LongTensor) – the indices of elements to gather
二維舉例,我們想構建這樣的一個二維表,假設 二維表是 a[i][j] 形式,
a[i][j] i表示0維,j表示1維
a[0][0]=0 | a[0][1]=1 | a[0][2]=2 | a[0][3]=3 |
---|---|---|---|
a[1][0]=4 | a[1][1]=5 | a[1][2]=6 | a[1][3]=7 |
a[2][0]=8 | a[2][1]=9 | a[2][2]=10 | a[2][3]=11 |
a[3][0]=12 | a[3][1]=13 | a[3][2]=14 | a[3][3]=15 |
import torch
a = torch.arange(0, 16).view(4, 4)
print(a)
# tensor([[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11],
# [12, 13, 14, 15]])
#---------------------------------------------------------------------------------
index_1 = torch.LongTensor([[0, 1, 2, 3]])
print(a.gather(0, index_1))# 作用在0維上索引是index_1,相當於i=index_1, j是不變的0,1,2,3
#輸出結果
#tensor([[ 0, 5, 10, 15]])
# 從0維找數據,相當於
# a[i][j]
# a[0][0] = 0
# a[1][1] = 5
# a[2][2] = 10
# a[3][3] = 15
#---------------------------------------------------------------------------------
index_2 = torch.LongTensor([[3, 2, 1, 0]])
print(a.gather(0, index_2))#作用在0維上索引是index_2,相當於i=index_2, j是不變的0,1,2,3
#輸出結果
#tensor([[12, 9, 6, 3]])
# 從0維找數據,相當於
a[i][j]
a[3][0] = 12
a[2][1] = 9
a[1][2] = 6
a[0][3] = 3
#---------------------------------------------------------------------------------
index_3 = torch.LongTensor([[0, 1, 2, 3]]).t()
print(a.gather(1, index_3))#作用在1維上索引是index_3,相當於i是不變的0,1,2,3,j=index_3
#輸出結果
# tensor([[ 0],
# [ 5],
# [10],
# [15]])
a[i][j]
a[0][0] = 0
a[1][1] = 5
a[2][2] = 10
a[3][3] = 15
#---------------------------------------------------------------------------------
如果不想 加.t() 可以換成下面的形式
index_4 = torch.LongTensor([[3], [2], [1], [0]])
print(a.gather(1, index_4))##作用在1維上索引是index_4,相當於i是不變的0,1,2,3,j=index_4
#輸出結果
# tensor([[ 3],
# [ 6],
# [ 9],
# [12]])
# a[i][j]
# a[0][3] = 3
# a[1][2] = 6
# a[2][1] = 9
# a[3][0] = 12
#---------------------------------------------------------------------------------
再隨意舉個例子
index_5 = torch.LongTensor([[3, 3, 3, 3]])
print(a.gather(0, index_5))#作用在0維上索引是index_2,相當於i=index_5, j是不變的0,1,2,3
#輸出結果
#tensor([[12, 13, 14, 15]])
# 從0維找數據,相當於
a[i][j]
a[3][0] = 12
a[3][1] = 13
a[3][2] = 14
a[3][3] = 15