pytorch-全面講解函數topk, scatter, gather

這三個函數在pytorch中關於矩陣操作的非常實用的函數。我認爲要想熟練的使用pytorch,能夠靈活的使用這三個函數是至關重要的

三者的相同點:維度->數據的映射方式

因爲三者都存在相似的地方,所以我這裏放在一起來講。這個共同點就是index -> value的方式:這裏以官方給的gather函數對應爲例:

# for a 3-D data
out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

這樣一看,並不好理解,舉個例子:

  • 關於shape的變化:
    • 輸入爲[3, 4, 1]的數據x | 它的index爲[3, 2, 1]
    • x.gather(dim=1, index)輸出維度爲[3, 2, 1]。它保持另外兩維不變,僅在這一維上操作。
  • 關於數據的變化
    • idx中的數據代表在指定維度上的index。 在這裏插入圖片描述

topk

其實前面講的映射方式計算起來還是容易亂,不過幸好並不影響我們的使用。emm實在不能理解可以忽略,只需要知道在指定維度上操作即可

torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)
  • 主要用途:依照大小,從矩陣某維度取值和取索引。常與scatter、gather連用。
  • 函數返回兩個變量:value和index
  • 維度變化:假設指定維度爲1,則(b, n, m)-> (b, k, m)
  • 其它用途:topk的數據默認按照從大到小排列,因此我們可以當做矩陣中的數據排序來用,若largest=False則爲升序:

在這裏插入圖片描述

gather

torch.gather(input, dim, index, out=None) → Tensor
  • 用途:依照index來對矩陣進行取值
  • 函數返回與輸入idx維度相同的tensor
  • 維度變化:假設指定dim=1,index=(b,k,n),input=(b,m,n)。則輸出爲(b,k,n)。在維度1上按照index進行取值。

scatter

torch.scatter(input, dim, index, src) → Tensor
  • 用途:與gather類似,不過它並不用來取值。scatter用來更替矩陣中指定index位置的值。

  • 維度變化:假設指定dim=1,index=(b,k,n),input=(b,m,n),src=(b,m,n)。則輸出爲(b,k,n)。在維度1上按照index從src取值,然後替換到input上相同的index位置。

  • 兩種用法:

    • 一般要求source的維度爲input維度相同,如下例:

    在這裏插入圖片描述

    • 當然,也可以直接指定要替換的值,如下:

    在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章