很多框架中提供的矩陣乘法都是出於簡化計算的考慮,很多情況下在進行計算時候都會牽扯到 batch size 這一個維度,這就使得很多矩陣的計算是三維的,Pytorch中的bmm()函數就可以很方便的實現三維數組的乘法,而不用拆成二維數組使用for循環解決。在查資料的時候發現有些博客寫的有些小地方不太對,而且有很多提問都是關於 bmm()函數具體是如何計算的,因此記錄。
1.torch.bmm()
函數定義:
def bmm(self: Tensor,
mat2: Tensor,
*,
out: Optional[Tensor] = None) -> Tensor
函數的傳入參數很簡單,兩個三維矩陣而已,只是要注意這兩個矩陣的shape有一些要求:
res = torch.bmm(ma, mb)
ma: [a, b, c]
mb: [a, c, d]
也就是說兩個tensor的第一維是相等的,然後第一個數組的第三維和第二個數組的第二維度要求一樣,對於剩下的則不做要求,其實這裏的意思已經很明白了,兩個三維矩陣的乘法其實就是保持第一維度不變,每次相當於一個切片做二維矩陣的乘法,對於上面的矩陣來說,就是 for i in range(a)
然後 ma[i] * mb[i]
,這是一個熟悉的二維矩陣乘法,兩個矩陣的shape分別是[b, c]
和[c, d]
。因此,輸出的結果的shape也很明顯了:[a, b, d]
。下面驗證一下:
2.驗證
首先創建兩個tensor:
a = torch.linspace(1, 24, 24).view(2, 3, 4) # shape [2, 3, 4]
b = torch.linspace(1, 16, 16).view(2, 4, 2) # shape [2, 4, 2]
兩個tensor分別是:
tensor([[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]],
[[13., 14., 15., 16.],
[17., 18., 19., 20.],
[21., 22., 23., 24.]]])
tensor([[[ 1., 2.],
[ 3., 4.],
[ 5., 6.],
[ 7., 8.]],
[[ 9., 10.],
[11., 12.],
[13., 14.],
[15., 16.]]])
接下來分別使用bmm函數和for循環方式實現乘法:
c = torch.bmm(a, b)
print(c)
d = np.array([torch.mm(a[i], b[i]).numpy() for i in range(len(a))])
print(d)
輸出分別是:
tensor([[[ 50., 60.],
[ 114., 140.],
[ 178., 220.]],
[[ 706., 764.],
[ 898., 972.],
[1090., 1180.]]])
[[[ 50. 60.]
[ 114. 140.]
[ 178. 220.]]
[[ 706. 764.]
[ 898. 972.]
[1090. 1180.]]]
也可以使用函數檢查一下:
print((d == c.numpy()).all())
輸出:True
3.更實際一點的想法
就像剛纔所說的那樣,只要根據實際的情況考慮一下,這個函數的計算過程很好理解,由於 batch size的引入,所以處理數據的時候很容易出現三維數組,例如處理文本計算attention權重的時候,很容易得到的權重矩陣shape是 [batch_size, sequence_length]
,然後需要相乘的隱狀態矩陣是 [batch_size, sequence_length, hidden_size]
。按照attention的計算方式,實際上就是權重矩陣中每一行的數值分別乘以隱狀態矩陣中每一行的對應位置的隱狀態,這個過程當然可以寫循環,也可以簡單的使用bmm函數計算,先將權重矩陣reshape成 [batch_size, 1, sequence_length]
然後bmm(weigths_matrix, hidden_matrix)
然後得到的結果就是attention計算的結果了。