pytorch函數之scatter()和scatter_()

前言

這兩個函數,其實本來有一個大佬寫的比較清楚了,但是說實話,總是給忘具體使用細節。我還是自己寫一個更清晰的吧。

官方文檔

scatter_()
scatter_(input, dim, index, src) → Tensor

其實這樣寫會造成迷惑,建議這麼按下面的理解:
理解

input.scatter_(dim, index, src) → Tensor
  • input: 我們需要插入數據的起源tensor;也就是想要改變內部數據的tensor
  • dim:我們想要從哪個維度去改input數據
  • index:給出改的元素索引,也就是位置,說在“座標”可能好理解一點。
  • src:準備好的插入的數據。

重點

  1. scatter()scatter_()都是一個東西;舉例如下

區別:是否改變input

# 準備數據
_input = torch.rand(2,3)
src = 3
# _input  →  
tensor([[0.8648, 0.0797, 0.5591],
        [0.0143, 0.9793, 0.5106]])
  
 '1.使用scatter()'
 _input.scatter_(0, torch.tensor([[0, 1, 0], [1, 0, 1]]), src)
 #output → 
 tensor([[3., 3., 3.],
        [3., 3., 3.]])
# 檢查_input是否被改變
'發現 _input 被改變了'
tensor([[3., 3., 3.],
        [3., 3., 3.]])
    
 '2.使用scatter_(),你會發現_input沒有改變'
 # output → 
 tensor([[0.8648, 0.0797, 0.5591],
        [0.0143, 0.9793, 0.5106]])
  1. 再來說說最難理解的index
# 首先確定數據
# _input  →  
tensor([[0.8648, 0.0797, 0.5591],
        [0.0143, 0.9793, 0.5106]])
 # index →
 torch.tensor([[0, 1, 0], [1, 0, 1]])
 # src →
 3

第一點:以上面爲例:

  • dim=0時候,必須index的dim=0爲3:你看上面的indextorch.tensor([[0, 1, 0], [1, 0, 1]])len([0,1,0])是不是3。
  • dim=1時候,必須index在dim=1的位置等於2

你們可以試試不改index,只更改裏面的dim測試一下

第二點: index裏面的元素必須是整數大小必須在0(index.size(dim)-1)

  • dim=0:index.size(0)爲2,元素必須在01,因此裏面元素都是01
  • dim=1 : 同理,index.size(1)爲3, 裏面元素必須在小於3.

第三點index是有它的套路的。很頭疼
通常來說,對於二維的input

  • dim=0
src[i][j] = input[index[i][j]][j]

還是按上面的例子。

  • dim=1
src[i][j] = input[i][index[i][j]]

算了,回頭再講,這個已經花我2小時了,我還得改文章。。。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章