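torch.mul and the * operator are equivalent: both perform element-wise multiplication with broadcasting (useful in attention).
Multiplying each row by a different element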
>>> a = torch.ones(3,4)
>>> a
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
>>> b = torch.Tensor([1,2,3]).reshape((3,1))
>>> b
tensor([[1.],
        [2.],
        [3.]])
>>> a * b
tensor([[1., 1., 1., 1.],
        [2., 2., 2., 2.],
        [3., 3., 3., 3.]])
Multiplying each column by a different element
>>> a = torch.ones(3,4)
>>> a
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
>>> b = torch.Tensor([1,2,3,4]).reshape((1,4))
>>> b
tensor([[1., 2., 3., 4.]])
>>> a*b
tensor([[1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.]])
>>> torch.mul(a,b)
tensor([[1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.]])
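As a rough illustration of why this matters for attention, the sketch below scales each row (think of it as one token's feature vector) by its own weight, and checks that the operator and function forms agree. The names `values` and `weights` and the weight values are made up for the example; they are not from any particular attention implementation.

```python
import torch

# Hypothetical setup: 3 tokens, 4 features each.
values = torch.ones(3, 4)
# One weight per token, shaped (3, 1) so it broadcasts along the feature axis.
weights = torch.tensor([0.2, 0.3, 0.5]).reshape(3, 1)

weighted_op = values * weights            # operator form
weighted_fn = torch.mul(values, weights)  # function form

# Both forms should produce identical tensors of shape (3, 4).
same = torch.equal(weighted_op, weighted_fn)
print(same, weighted_op.shape)
```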
With a batch dimension (mul and * automatically broadcast across all batches)
>>> a=torch.ones(2,3,4)
>>> a
tensor([[[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]]])
>>> b = torch.Tensor([1,2,3,4]).reshape((1,4))
>>> b
tensor([[1., 2., 3., 4.]])
>>> a*b
tensor([[[1., 2., 3., 4.],
         [1., 2., 3., 4.],
         [1., 2., 3., 4.]],

        [[1., 2., 3., 4.],
         [1., 2., 3., 4.],
         [1., 2., 3., 4.]]])
>>> torch.mul(a,b)
tensor([[[1., 2., 3., 4.],
         [1., 2., 3., 4.],
         [1., 2., 3., 4.]],

        [[1., 2., 3., 4.],
         [1., 2., 3., 4.],
         [1., 2., 3., 4.]]])
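One way to see what broadcasting does here is to expand the smaller tensor explicitly and compare. This is only a sketch: `expand_as` is used to make the implicit step visible, not because it is needed in practice.

```python
import torch

a = torch.ones(2, 3, 4)
b = torch.tensor([1., 2., 3., 4.]).reshape(1, 4)

# Broadcasting aligns shapes from the right: (1, 4) is treated as (1, 1, 4),
# and the size-1 dimensions are then stretched to match (2, 3, 4).
b_expanded = b.expand_as(a)   # explicit version of the implicit step

implicit = a * b              # relies on automatic broadcasting
explicit = a * b_expanded     # shapes already match

matches = torch.equal(implicit, explicit)
print(matches, implicit.shape)
```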
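Beyond element-wise multiplication, there are functions for vector and matrix products, such as torch.dot(), torch.mm(), and torch.bmm().
torch.dot() computes the dot product of two 1-D vectors; passing 2-D tensors raises an error.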
In [252]: a=torch.randn(2,3)
In [253]: b=torch.randn(3,2)
In [254]: torch.dot(a,b)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-254-4939a8ae602a> in <module>()
----> 1 torch.dot(a,b)
RuntimeError: dot: Expected 1-D argument self, but got 2-D
In [255]: a=torch.randn(3)
In [256]: b=torch.randn(3)
In [257]: torch.dot(a,b)
Out[257]: tensor(0.0967)
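To connect this with the element-wise product above: a dot product is the sum of the element-wise products. A minimal check with fixed values (chosen here only so the result is easy to verify by hand):

```python
import torch

a = torch.tensor([1., 2., 3.])
b = torch.tensor([4., 5., 6.])

dot_builtin = torch.dot(a, b)   # 0-D tensor
dot_manual = (a * b).sum()      # element-wise product, then sum

# 1*4 + 2*5 + 3*6 = 32
print(dot_builtin.item(), dot_manual.item())
```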
torch.mm performs matrix multiplication of two matrices (2-D tensors only).
In [258]: a=torch.randn(2,3)
In [259]: b=torch.randn(3,2)
In [260]: torch.mm(a,b)
Out[260]:
tensor([[-1.2849,  0.1272],
        [ 0.0600, -0.3183]])
In [261]: torch.mm(a,b).size()
Out[261]: torch.Size([2, 2])
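For intuition, each entry of the torch.mm result is the dot product of a row of the first matrix with a column of the second. The sketch below rebuilds the product with explicit loops; the input values are arbitrary and the loop is for illustration only, not something to use for speed.

```python
import torch

a = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.]])   # shape (2, 3)
b = torch.tensor([[1., 0.],
                  [0., 1.],
                  [1., 1.]])       # shape (3, 2)

product = torch.mm(a, b)           # shape (2, 2)

# Rebuild the same result entry by entry: row i of a dotted with column j of b.
manual = torch.empty(2, 2)
for i in range(2):
    for j in range(2):
        manual[i, j] = torch.dot(a[i], b[:, j])

print(product)
print(torch.allclose(product, manual))
```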
torch.bmm() performs matrix multiplication over a batch of 2-D matrices (here batch_size=2); torch.mm rejects 3-D input.
In [262]: a=torch.randn(2,2,3)
In [263]: b=torch.randn(2,3,2)
In [264]: torch.mm(a,b).size()
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-264-6bdb1a27d804> in <module>()
----> 1 torch.mm(a,b).size()
RuntimeError: matrices expected, got 3D, 3D tensors at /Users/soumith/code/builder/wheel/pytorch-src/aten/src/TH/generic/THTensorMath.cpp:2065
In [265]: torch.bmm(a,b).size()
Out[265]: torch.Size([2, 2, 2])
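Conceptually, torch.bmm behaves like applying torch.mm to each pair of matrices in the batch and stacking the results. A small sketch of that equivalence, using random inputs with a fixed seed:

```python
import torch

torch.manual_seed(0)
a = torch.randn(2, 2, 3)   # batch of 2 matrices, each (2, 3)
b = torch.randn(2, 3, 2)   # batch of 2 matrices, each (3, 2)

batched = torch.bmm(a, b)  # shape (2, 2, 2)

# Multiply each pair separately with mm, then stack along a new batch axis.
looped = torch.stack([torch.mm(a[i], b[i]) for i in range(a.size(0))])

print(batched.shape, torch.allclose(batched, looped, atol=1e-6))
```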
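What if the tensors have more than 3 dimensions and we still want to multiply the individual matrices? For example, in multi-head attention with batch=2 and head=2.
torch.matmul multiplies over only the last two dimensions and treats all leading dimensions as batch dimensions; torch.bmm rejects 4-D input.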
In [266]: b=torch.randn(2,2,3,2)
In [267]: a=torch.randn(2,2,2,3)
In [268]: torch.bmm(a,b).size()
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-268-3084f0e99edf> in <module>()
----> 1 torch.bmm(a,b).size()
RuntimeError: invalid argument 1: expected 3D tensor, got 4D at /Users/soumith/code/builder/wheel/pytorch-src/aten/src/TH/generic/THTensorMath.cpp:2304
In [269]: torch.matmul(a,b).size()
Out[269]: torch.Size([2, 2, 2, 2])
In [270]: a=torch.randn(2,2,3)
In [271]: b=torch.randn(2,3,2)
In [272]: torch.matmul(a,b).size()
Out[272]: torch.Size([2, 2, 2])
So torch.matmul can also do everything torch.bmm does.
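As a sketch of the multi-head attention case, the snippet below computes raw attention scores (query times transposed key) with torch.matmul on 4-D tensors, then reproduces them with torch.bmm by folding the batch and head dimensions together. The sizes and the names `q` and `k` are hypothetical, and the usual scaling and softmax steps are left out.

```python
import torch

torch.manual_seed(0)
batch, heads, seq_len, head_dim = 2, 2, 5, 3   # hypothetical sizes

q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)

# matmul multiplies the last two dims; (batch, heads) act as batch dimensions.
scores = torch.matmul(q, k.transpose(-2, -1))   # (batch, heads, seq_len, seq_len)

# The same computation with bmm, after merging batch and heads into one axis.
q3 = q.reshape(batch * heads, seq_len, head_dim)
k3 = k.reshape(batch * heads, seq_len, head_dim)
scores_bmm = torch.bmm(q3, k3.transpose(1, 2)).reshape(batch, heads, seq_len, seq_len)

print(scores.shape, torch.allclose(scores, scores_bmm, atol=1e-6))
```

With matmul the reshape round trip is unnecessary, which is the main convenience when tensors carry more than one leading dimension.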