pytorch中的nn.Bilinear

參考:pytorch中的nn.Bilinear的計算原理詳解

在這裏插入圖片描述

在這裏插入圖片描述

代碼實現

使用numpy實現Bilinear(來自參考資料):

print('learn nn.Bilinear')
m = nn.Bilinear(20, 30, 40)
input1 = torch.randn(128, 20)
input2 = torch.randn(128, 30)
output = m(input1, input2)
print(output.size())
arr_output = output.data.cpu().numpy()
 
weight = m.weight.data.cpu().numpy()
bias = m.bias.data.cpu().numpy()
x1 = input1.data.cpu().numpy()
x2 = input2.data.cpu().numpy()
print(x1.shape,weight.shape,x2.shape,bias.shape)
y = np.zeros((x1.shape[0],weight.shape[0]))
for k in range(weight.shape[0]):
    buff = np.dot(x1, weight[k])
    buff = buff * x2
    buff = np.sum(buff,axis=1)
    y[:,k] = buff
y += bias
dif = y - arr_output
print(np.mean(np.abs(dif.flatten())))

輸出結果:

在這裏插入圖片描述

可以看到我們自己用numpy實現的Bilinear跟調用pytorch的Bilinear的輸出結果的誤差在小數點後7位,通過編寫這個程序,現在可以理解Bilinear的計算過程了。需要注意的是Bilinear的weight是一個3維矩陣,這是跟nn.linear的一個最大區別。

首先,以weight的第0維開始,逐個遍歷weight的每一頁,當遍歷到第k頁時,輸入x1與weight[k,:,:]做矩陣乘法得到buff,然後buff與輸入x2做矩陣點乘得到新的buff,接下來對buff在第1個維度,即按行求和得到新的buff,這時把buff的值賦值給輸出y的第k列

遍歷完weight的每一頁之後,加上偏置項,這時候Bilinear的計算就完成了。爲了檢驗編寫的numpy程序是否正確,我們把輸出y跟調用pytorch的nn.Bilinear得到的輸出output轉成numpy形式的arr_output做誤差比較。

公式

X1X_1 形狀爲 [batch_size,input1][batch\_size,input_1]X2X_2 形狀爲 [batch_size,input2][batch\_size,input_2]

nn.Bilinear 內部的參數形狀爲:

參數 W:[output,input1,input2]W:[output,input_1,input_2],令 Wk=W[k,:,:],1koutputW_k=W[k,:,:], 1\leq k \leq output,其形狀爲 [input1,input2][input_1,input_2] ;

參數 b:[output]b:[output]

下述代碼使用nn.Bilinear得到 YY,其形狀爲 [batch_size,output][batch\_size,output]

m=nn.Bilinear(input_1,input_2,output)
Y=m(X_1,X_2)

實際計算公式使用python語法可以表達爲:

Y=concatenate([sum(X1WkX2,axis=1)forWkinW],axis=1)+bY=concatenate([sum(X_1W_k \odot X_2, axis=1) \quad for \quad W_k\quad in \quad W], axis=1)+b

其中X1X_1WkW_k之間使用矩陣乘法,其結果與 X2X_2 使用逐元素乘法;

sum(tensor,axis=1)sum(tensor,axis=1) 表示將在軸 1 方向上求和進行歸約(維度減一)。

concatenate(list_of_tensor,axis=1)concatenate(list\_of\_tensor,axis=1) 表示將多個tensor在軸 1 方向上進行拼接。

最後的加法爲廣播加法,因爲加號左邊維度爲[batch_size,output][batch\_size,output],而右邊爲 [output][output]

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章