參考:pytorch中的nn.Bilinear的計算原理詳解
代碼實現
使用numpy實現Bilinear(來自參考資料):
print('learn nn.Bilinear')
m = nn.Bilinear(20, 30, 40)
input1 = torch.randn(128, 20)
input2 = torch.randn(128, 30)
output = m(input1, input2)
print(output.size())
arr_output = output.data.cpu().numpy()
weight = m.weight.data.cpu().numpy()
bias = m.bias.data.cpu().numpy()
x1 = input1.data.cpu().numpy()
x2 = input2.data.cpu().numpy()
print(x1.shape,weight.shape,x2.shape,bias.shape)
y = np.zeros((x1.shape[0],weight.shape[0]))
for k in range(weight.shape[0]):
buff = np.dot(x1, weight[k])
buff = buff * x2
buff = np.sum(buff,axis=1)
y[:,k] = buff
y += bias
dif = y - arr_output
print(np.mean(np.abs(dif.flatten())))
輸出結果:
可以看到我們自己用numpy實現的Bilinear跟調用pytorch的Bilinear的輸出結果的誤差在小數點後7位,通過編寫這個程序,現在可以理解Bilinear的計算過程了。需要注意的是Bilinear的weight是一個3維矩陣,這是跟nn.linear的一個最大區別。
首先,以weight的第0維開始,逐個遍歷weight的每一頁,當遍歷到第k頁時,輸入x1與weight[k,:,:]做矩陣乘法得到buff,然後buff與輸入x2做矩陣點乘得到新的buff,接下來對buff在第1個維度,即按行求和得到新的buff,這時把buff的值賦值給輸出y的第k列
遍歷完weight的每一頁之後,加上偏置項,這時候Bilinear的計算就完成了。爲了檢驗編寫的numpy程序是否正確,我們把輸出y跟調用pytorch的nn.Bilinear得到的輸出output轉成numpy形式的arr_output做誤差比較。
公式
設 形狀爲 , 形狀爲
nn.Bilinear
內部的參數形狀爲:
參數 ,令 ,其形狀爲 ;
參數
下述代碼使用nn.Bilinear
得到 ,其形狀爲 。
m=nn.Bilinear(input_1,input_2,output)
Y=m(X_1,X_2)
實際計算公式使用python語法可以表達爲:
其中與之間使用矩陣乘法,其結果與 使用逐元素乘法;
表示將在軸 1 方向上求和進行歸約(維度減一)。
表示將多個tensor在軸 1 方向上進行拼接。
最後的加法爲廣播加法,因爲加號左邊維度爲,而右邊爲