這三個函數在pytorch中關於矩陣操作的非常實用的函數。我認爲要想熟練的使用pytorch,能夠靈活的使用這三個函數是至關重要的
三者的相同點:維度->數據的映射方式
因爲三者都存在相似的地方,所以我這裏放在一起來講。這個共同點就是index -> value的方式:這裏以官方給的gather函數對應爲例:
# for a 3-D data
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
這樣一看,並不好理解,舉個例子:
- 關於shape的變化:
- 輸入爲[3, 4, 1]的數據x | 它的index爲[3, 2, 1]
- x.gather(dim=1, index)輸出維度爲[3, 2, 1]。它保持另外兩維不變,僅在這一維上操作。
- 關於數據的變化
- idx中的數據代表在指定維度上的index。
topk
其實前面講的映射方式計算起來還是容易亂,不過幸好並不影響我們的使用。emm實在不能理解可以忽略,只需要知道在指定維度上操作即可
torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)
- 主要用途:依照大小,從矩陣某維度取值和取索引。常與scatter、gather連用。
- 函數返回兩個變量:value和index
- 維度變化:假設指定維度爲1,則(b, n, m)-> (b, k, m)
- 其它用途:topk的數據默認按照從大到小排列,因此我們可以當做矩陣中的數據排序來用,若largest=False則爲升序:
gather
torch.gather(input, dim, index, out=None) → Tensor
- 用途:依照index來對矩陣進行取值
- 函數返回與輸入idx維度相同的tensor
- 維度變化:假設指定dim=1,index=(b,k,n),input=(b,m,n)。則輸出爲(b,k,n)。在維度1上按照index進行取值。
scatter
torch.scatter(input, dim, index, src) → Tensor
-
用途:與gather類似,不過它並不用來取值。scatter用來更替矩陣中指定index位置的值。
-
維度變化:假設指定dim=1,index=(b,k,n),input=(b,m,n),src=(b,m,n)。則輸出爲(b,k,n)。在維度1上按照index從src取值,然後替換到input上相同的index位置。
-
兩種用法:
- 一般要求source的維度爲input維度相同,如下例:
- 當然,也可以直接指定要替換的值,如下: