torch_scatter.scatter_add、Tensor.scatter_add_ 、Tensor.scatter_、Tensor.scatter_add 、Tensor.scatter

 torch_scatter.scatter_add

官方文檔:torch_scatter.scatter_add(srcindexdim=-1out=Nonedim_size=Nonefill_value=0)

Sums all values from the src tensor into out at the indices specified in the index tensor along a given axis dim. For each value in src, its output index is specified by its index in input for dimensions outside of dim and by the corresponding value in index for dimension dim. If multiple indices reference the same location, their contributions add.

看着挺疑惑的,自己試了一把:

src = torch.tensor([10, 20, 30, 40, 1, 2, 2, 2, 9])
index = torch.tensor([2, 1, 1, 1, 1, 1, 1, 1, 0])
out=scatter_add(src, index)
print(out)

輸出結果爲:tensor([ 9, 97, 10])

說白了就是:index就是out的下標,將src所有和此下標對應的值加起來,就是out的值。

例如上面的例子:index中等於1的,對應於src是【20, 30, 40, 1, 2, 2, 2】,將這些值加起來是97,於是,out[1]=97

同理:out[0]=src[8]=9     out[2]=src[0]=10

 

另一個函數

Tensor.scatter_add_

官方文檔:

scatter_add_(self, dim, index, other):
For a 3-D tensor, :attr:`self` is updated as::

    self[index[i][j][k]][j][k] += other[i][j][k]  # if dim == 0
    self[i][index[i][j][k]][k] += other[i][j][k]  # if dim == 1
    self[i][j][index[i][j][k]] += other[i][j][k]  # if dim == 2

官方例子:

            >>> x = torch.rand(2, 5)
            >>> x
            tensor([[0.7404, 0.0427, 0.6480, 0.3806, 0.8328],
                    [0.7953, 0.2009, 0.9154, 0.6782, 0.9620]])
            >>> torch.ones(3, 5).scatter_add_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
            tensor([[1.7404, 1.2009, 1.9154, 1.3806, 1.8328],
                    [1.0000, 1.0427, 1.0000, 1.6782, 1.0000],
                    [1.7953, 1.0000, 1.6480, 1.0000, 1.9620]])

以index來遍歷,就比較容易看懂。self中並不是每個值都要改變的。

以上面爲例 index[0][0]=0  self[index[0][0]][0]=self[0][0] =self[0][0]+ x[0][0]=1 +0.7404=1.7404

                   index[0][1]=1  self[index[0][1]][1]=self[1][1] =self[1][1]+ x[0][1] =1 +0.0427 =1.0427

                   。。。

                  以此類推,將index遍歷一遍,就得到最終的結果

所以,self中需要改變的是index中列出的座標,其他的是不動的。


Tensor.scatter_

scatter_(self, dim, index, src)

和Tensor.scatter_add_的區別是直接將src中的值填充到self中,不做相加

例子:

>>> x = torch.rand(2, 5)
            >>> x
            tensor([[ 0.3992,  0.2908,  0.9044,  0.4850,  0.6004],
                    [ 0.5735,  0.9006,  0.6797,  0.4152,  0.1732]])
            >>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
            tensor([[ 0.3992,  0.9006,  0.6797,  0.4850,  0.6004],
                    [ 0.0000,  0.2908,  0.0000,  0.4152,  0.0000],
                    [ 0.5735,  0.0000,  0.9044,  0.0000,  0.1732]])
        
            >>> z = torch.zeros(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23)
            >>> z
            tensor([[ 0.0000,  0.0000,  1.2300,  0.0000],
                    [ 0.0000,  0.0000,  0.0000,  1.2300]])

另外,pytorch中還有

scatter_add和scatter函數,和上面兩個函數不同的是這個兩個函數不改變self,會返回結果值;上面兩個函數(scatter_add_和scatter_)是直接在原數據self上進行修改
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章