one hot編碼:`torch.Tensor.scatter_()`函數用法詳解

torch.Tensor.scatter_()torch.gather()函數的方向反向操作。兩個函數可以看成一對兄弟函數。gather用來解碼one hot,scatter_用來編碼one hot。

scatter_(dim, index, src) → Tensor

  • dim (python:int) – 用來尋址的座標軸
  • index (LongTensor) – 索引
  • src(Tensor) –用來scatter的源張量,以防value未被指定。
  • value(python:float) – 用來scatter的源張量,以防src未被指定。

現在我們來看看具體這麼用,看下面這個例子就一目瞭然了。

  • dim =0
import torch

x = torch.tensor([[0.9413, 0.9476, 0.1104, 0.9898, 0.6443],
            [0.6913, 0.8924, 0.7530, 0.8874, 0.0557]])

result = torch.zeros(3, 5)
indices = torch.tensor([[0, 1, 2, 0, 0], 
                        [2, 0, 0, 1, 2]])
result.scatter_(0, indices, x)

輸出爲

tensor([[0.9413, 0.8924, 0.7530, 0.9898, 0.6443],
        [0.0000, 0.9476, 0.0000, 0.8874, 0.0000],
        [0.6913, 0.0000, 0.1104, 0.0000, 0.0557]])

dim = 0的情形:

比如上例中,dim=0,所以根據這個規則來self[index[i][j]][j] = src[i][j]來確定替換規則。

index中的值決定了src中的值在result中的放置位置。

dim=0時,則將列固定起來,先看第0列:

對於第0行,首先找到x的第0列第0行的值爲0.9413,然後在用index[0][0]的值來找將要在result中放置的位置。

在這個例子中,index[0][0]=0, 所以0.9413將放置在result[0][0]這個位置。

對於result中的各項,他們的尋址過程如下:

x[0][1] = 0.9476 -> indices[0][1]=1 -> result[ index = 1 ][1] = 0.9476

x[1][3] = 0.8874 -> indices[1][3]=1 -> result[ index = 1 ][3] = 0.8874

依此類推。

以下爲dim = 1的情形:

x[0][0] = 0.9413 -> indices[0][0]=0 -> result[0][index = 0] = 0.9413

x[0][3] = 0.9898 -> indices[0][3]=0 -> result[0][index = 0] = 0.9898 ## 將上一步的值覆蓋了

x[0][4] = 0.6443 -> indices[0][4]=0 -> result[0][index = 0] = 0.6443 ## 再次將上一步的值覆蓋了

因此result[0][0]的值爲0.6443.

dim = 1

x = torch.tensor([[0.9413, 0.9476, 0.1104, 0.9898, 0.6443],
                        [0.6913, 0.8924, 0.7530, 0.8874, 0.0557]])
result = torch.zeros(3, 5)
indices = torch.tensor([[0, 1, 2, 0, 0], 
                        [2, 0, 0, 1, 2]])
result.scatter_(1, indices, x)
            

輸出爲

tensor([[0.6443, 0.9476, 0.1104, 0.0000, 0.0000],
        [0.7530, 0.8874, 0.0557, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])

用於產生one hot編碼的向量

當沒有src值時,則所有用於填充的值均爲value值。

需要注意的時候,這個時候index.shape[dim]必須與result.shape[dim]相等,否則會報錯。

result = torch.zeros(3, 5)
indices = torch.tensor([[0, 1, 2, 0, 0], 
                        [2, 0, 3, 1, 2],
                        [2, 1, 3, 1, 4]])
result.scatter_(1, indices, value=1)        

輸出爲

tensor([[1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [0., 1., 1., 1., 1.]])

例如 indices = [1,2,3,4,5],將他轉換爲one-hot的形式.

indices = torch.tensor(list(range(5))).view(5,1)
result = torch.zeros(5, 5)
result.scatter_(1, indices, 1)        

輸出爲

tensor([[1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.]])
發佈了25 篇原創文章 · 獲贊 3 · 訪問量 1萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章