pytorch中的 scatter_()函數使用和詳解

scatter(dim, index, src)的三個參數爲:

(1)dim:沿着哪個維度進行索引

(2)index: 用來scatter的元素索引

(3)src: 用來scatter的源元素,可以使一個標量也可以是一個張量

官方給的例子爲三維情況下的例子:

y = y.scatter(dim,index,src)

#則結果爲:
y[ index[i][j][k]  ] [j][k] = src[i][j][k] # if dim == 0
y[i] [ index[i][j][k] ] [k] = src[i][j][k] # if dim == 1
y[i][j] [ index[i][j][k] ]  = src[i][j][k] # if dim == 2

如果是二維的例子,則應該對應下面的情況:

y = y.scatter(dim,index,src)

#則:
y [ index[i][j] ] [j] = src[i][j] #if dim==0
y[i] [ index[i][j] ]  = src[i][j] #if dim==1 

我們舉一個實際的例子:

import torch

x = torch.randn(2,4)
print(x)
y = torch.zeros(3,4)
y = y.scatter_(0,torch.LongTensor([[2,1,2,2],[0,2,1,1]]),x)
print(y)


#結果爲:
tensor([[-0.9669, -0.4518,  1.7987,  0.1546],
        [-0.1122, -0.7998,  0.6075,  1.0192]])
tensor([[-0.1122,  0.0000,  0.0000,  0.0000],
        [ 0.0000, -0.4518,  0.6075,  1.0192],
        [-0.9669, -0.7998,  1.7987,  0.1546]])


'''
scatter後:
y[ index[0][0] ] [0] = src[0][0] -> y[2][0]=-0.9669

y[ index[1][3] ] [3] = src[1][3] -> y[1][3]=1.10192

'''

#如果src爲標量,則代表着將對應位置的數值改爲src這個標量

那麼這個函數有什麼作用呢?其實可以利用這個功能將pytorch 中mini batch中的返回的label(特指[ 1,0,4,9 ],即size爲[4]這樣的label)轉爲one-hot類型的label,舉例子如下:

import torch

mini_batch = 4
out_planes = 6
out_put = torch.rand(mini_batch, out_planes)
softmax = torch.nn.Softmax(dim=1)
out_put = softmax(out_put)

print(out_put)
label = torch.tensor([1,3,3,5])
one_hot_label = torch.zeros(mini_batch, out_planes).scatter_(1,label.unsqueeze(1),1)
print(one_hot_label)

上述的這個例子假設是一個分類問題,我設置out_planes=6,是假設總共有6類,mini_batch是我們送入的網絡的每個mini_batch的樣本數量,這裏我們不設置網絡,直接假設網絡的輸出爲一個隨機的張量 ,通常我們要對這個輸出進行softmax歸一化,此時就代表着其屬於每個類別的概率了。說到這裏都不是重點,就是爲了方便理解如何使用scatter,將size爲[mini_batch]的張量,轉爲size爲[mini_batch, out_palnes]的張量,並且這個生成的張量的每個行向量都是one-hot類型的了。通過看下面的輸出結果就完全能夠理解了,不理解,給我留言,我給你解釋清楚。

tensor([[0.1202, 0.2120, 0.1252, 0.1127, 0.2314, 0.1985],
        [0.1707, 0.1227, 0.2282, 0.0918, 0.1845, 0.2021],
        [0.1629, 0.1936, 0.1277, 0.1204, 0.1845, 0.2109],
        [0.1226, 0.1524, 0.2315, 0.2027, 0.1907, 0.1001]])
tensor([1, 3, 3, 5])
tensor([[0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1.]])

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章