torch.Tensor.scatter_()
是torch.gather()
函數的方向反向操作。兩個函數可以看成一對兄弟函數。gather
用來解碼one hot,scatter_
用來編碼one hot。
scatter_(dim, index, src) → Tensor
dim (python:int)
– 用來尋址的座標軸index (LongTensor)
– 索引src(Tensor)
–用來scatter的源張量,以防value未被指定。value(python:float)
– 用來scatter的源張量,以防src未被指定。
現在我們來看看具體這麼用,看下面這個例子就一目瞭然了。
- dim =0
import torch
x = torch.tensor([[0.9413, 0.9476, 0.1104, 0.9898, 0.6443],
[0.6913, 0.8924, 0.7530, 0.8874, 0.0557]])
result = torch.zeros(3, 5)
indices = torch.tensor([[0, 1, 2, 0, 0],
[2, 0, 0, 1, 2]])
result.scatter_(0, indices, x)
輸出爲
tensor([[0.9413, 0.8924, 0.7530, 0.9898, 0.6443],
[0.0000, 0.9476, 0.0000, 0.8874, 0.0000],
[0.6913, 0.0000, 0.1104, 0.0000, 0.0557]])
dim = 0的情形:
比如上例中,dim=0,所以根據這個規則來self[index[i][j]][j] = src[i][j]來確定替換規則。
index中的值決定了src中的值在result中的放置位置。
dim=0時,則將列固定起來,先看第0列:
對於第0行,首先找到x的第0列第0行的值爲0.9413,然後在用index[0][0]的值來找將要在result中放置的位置。
在這個例子中,index[0][0]=0, 所以0.9413將放置在result[0][0]這個位置。
對於result中的各項,他們的尋址過程如下:
x[0][1] = 0.9476 -> indices[0][1]=1 -> result[ index = 1 ][1] = 0.9476
x[1][3] = 0.8874 -> indices[1][3]=1 -> result[ index = 1 ][3] = 0.8874
依此類推。
以下爲dim = 1的情形:
x[0][0] = 0.9413 -> indices[0][0]=0 -> result[0][index = 0] = 0.9413
x[0][3] = 0.9898 -> indices[0][3]=0 -> result[0][index = 0] = 0.9898 ## 將上一步的值覆蓋了
x[0][4] = 0.6443 -> indices[0][4]=0 -> result[0][index = 0] = 0.6443 ## 再次將上一步的值覆蓋了
因此result[0][0]的值爲0.6443.
dim = 1
x = torch.tensor([[0.9413, 0.9476, 0.1104, 0.9898, 0.6443],
[0.6913, 0.8924, 0.7530, 0.8874, 0.0557]])
result = torch.zeros(3, 5)
indices = torch.tensor([[0, 1, 2, 0, 0],
[2, 0, 0, 1, 2]])
result.scatter_(1, indices, x)
輸出爲
tensor([[0.6443, 0.9476, 0.1104, 0.0000, 0.0000],
[0.7530, 0.8874, 0.0557, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
用於產生one hot編碼的向量
當沒有src
值時,則所有用於填充的值均爲value
值。
需要注意的時候,這個時候index.shape[dim]
必須與result.shape[dim]
相等,否則會報錯。
result = torch.zeros(3, 5)
indices = torch.tensor([[0, 1, 2, 0, 0],
[2, 0, 3, 1, 2],
[2, 1, 3, 1, 4]])
result.scatter_(1, indices, value=1)
輸出爲
tensor([[1., 1., 1., 0., 0.],
[1., 1., 1., 1., 0.],
[0., 1., 1., 1., 1.]])
例如 indices = [1,2,3,4,5],將他轉換爲one-hot的形式.
indices = torch.tensor(list(range(5))).view(5,1)
result = torch.zeros(5, 5)
result.scatter_(1, indices, 1)
輸出爲
tensor([[1., 0., 0., 0., 0.],
[0., 1., 0., 0., 0.],
[0., 0., 1., 0., 0.],
[0., 0., 0., 1., 0.],
[0., 0., 0., 0., 1.]])