前言
這兩個函數,其實本來有一個大佬寫的比較清楚了,但是說實話,總是給忘具體使用細節。我還是自己寫一個更清晰的吧。
官方文檔
scatter_()
scatter_(input, dim, index, src) → Tensor
其實這樣寫會造成迷惑,建議這麼按下面的理解:
理解
input.scatter_(dim, index, src) → Tensor
- input: 我們需要插入數據的起源
tensor
;也就是想要改變內部數據的tensor
- dim:我們想要從哪個維度去改
input
數據 - index:給出改的元素索引,也就是位置,說在“座標”可能好理解一點。
- src:準備好的插入的數據。
重點
scatter()
和scatter_()
都是一個東西;舉例如下
區別:是否改變input
:
# 準備數據
_input = torch.rand(2,3)
src = 3
# _input →
tensor([[0.8648, 0.0797, 0.5591],
[0.0143, 0.9793, 0.5106]])
'1.使用scatter()'
_input.scatter_(0, torch.tensor([[0, 1, 0], [1, 0, 1]]), src)
#output →
tensor([[3., 3., 3.],
[3., 3., 3.]])
# 檢查_input是否被改變
'發現 _input 被改變了'
tensor([[3., 3., 3.],
[3., 3., 3.]])
'2.使用scatter_(),你會發現_input沒有改變'
# output →
tensor([[0.8648, 0.0797, 0.5591],
[0.0143, 0.9793, 0.5106]])
- 再來說說最難理解的
index
。
# 首先確定數據
# _input →
tensor([[0.8648, 0.0797, 0.5591],
[0.0143, 0.9793, 0.5106]])
# index →
torch.tensor([[0, 1, 0], [1, 0, 1]])
# src →
3
第一點:以上面爲例:
dim=0
時候,必須index
的dim=0爲3
:你看上面的index
是torch.tensor([[0, 1, 0], [1, 0, 1]])
,len([0,1,0])
是不是3。dim=1
時候,必須index
在dim=1的位置等於2
。
你們可以試試不改index
,只更改裏面的dim
測試一下
第二點: index裏面的元素必須是整數,大小必須在0
到(index.size(dim)-1)
- dim=0:
index.size(0)
爲2,元素必須在0
到1
,因此裏面元素都是0
和1
。 - dim=1 : 同理,
index.size(1)
爲3, 裏面元素必須在小於3.
第三點:index
是有它的套路的。很頭疼。
通常來說,對於二維的input
dim=0
src[i][j] = input[index[i][j]][j]
還是按上面的例子。
dim=1
src[i][j] = input[i][index[i][j]]