torch.mul
Usage
torch.mul(input1, input2, out=None)  # element-wise multiplication
Behavior
1. If input1 and input2 are both matrices/vectors, they are multiplied element-wise.
2. If input1 is a matrix/vector and input2 is a scalar, every element of input1 is multiplied by input2.
3. If input1 is a matrix/vector and input2 is a vector/matrix with a different (but compatible) shape, the inputs are first broadcast against each other, then multiplied element-wise.
Examples
#1 input1 and input2 have the same size: element-wise multiplication
a = torch.tensor([[ 1.8351, 2.1536],
[-0.8320, -1.4578]])
b = torch.tensor([[2.9355, 0.3450],
[2.9355, 0.3450]])
c = torch.mul(a,b)
tensor([[ 5.3869, 0.7429],
[-2.4423, -0.5029]])
#3 input1 is a matrix and input2 is a vector: input2 is broadcast first, then multiplied element-wise
a = torch.tensor([[ 1.8351, 2.1536],
[-0.8320, -1.4578]])
b = torch.tensor([2.9355, 0.3450])
c = torch.mul(a,b)
tensor([[ 5.3869, 0.7429],
[-2.4423, -0.5029]])
As you can see, #1 and #3 give the same result.
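Case #2 (tensor times scalar) can be checked the same way; a minimal sketch with an illustrative tensor:

```python
import torch

# #2: input1 is a matrix, input2 is a scalar -> every element is scaled
a = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0]])
c = torch.mul(a, 2.0)  # equivalent to a * 2.0
print(c)
# tensor([[2., 4.],
#         [6., 8.]])
```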
torch.mm
Usage
torch.mm(input1, input2, out=None)  # 2-D matrix multiplication
Behavior
Multiplies two 2-D matrices: (a, b) × (b, d) = (a, d). It only handles 2-D matrices; for other dimensionalities use torch.matmul.
Examples
mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)
mat3 = torch.mm(mat1, mat2)
tensor([[ 0.4851, 0.5037, -0.3633],
[-0.0760, -3.6705, 2.4784]])
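The 2-D-only restriction can be checked directly; in this sketch (tensor shapes are illustrative) a 3-D input makes torch.mm raise a RuntimeError:

```python
import torch

mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)
print(torch.mm(mat1, mat2).size())  # torch.Size([2, 3])

# A batched (3-D) input is rejected; torch.matmul is needed instead
batched = torch.randn(6, 2, 3)
try:
    torch.mm(batched, mat2)
except RuntimeError as err:
    print("torch.mm only accepts 2-D matrices:", err)
```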
torch.bmm
Usage
torch.bmm(mat1, mat2, out=None)  # 3-D tensors: the first dimension is batch_size, the last two dimensions are multiplied as matrices
Behavior
As the name suggests, this is torch.mm with batched computation added on top; it does not broadcast.
Examples
mat1 = torch.randn(6, 2, 3)  # batch_size = 6
mat2 = torch.randn(6, 3, 4)
mat3 = torch.bmm(mat1, mat2)  # torch.Size([6, 2, 4])
tensor([[[-5.5404e-02, -1.2719e+00, -1.3952e+00, 7.2475e-01],
[ 1.0943e+00, 2.1826e+00, -4.4239e-01, -1.0643e+00]],
[[ 1.1785e+00, -4.9125e-01, -3.4894e-01, -2.1170e-02],
[-6.4008e-01, -2.4427e-03, -3.1276e-01, -4.5647e-01]],
[[-2.9938e-01, 7.6840e-01, -2.7852e-01, 5.4946e-01],
[ 4.2854e-01, 1.8301e+00, 1.7477e-02, -1.4107e+00]],
[[-2.7399e-01, 1.2810e+00, 1.8456e+00, -5.5862e-01],
[ 1.0337e+00, 1.3213e+00, 7.3194e-01, 3.9463e-01]],
[[-1.3685e-01, -9.7863e-02, -3.3586e-01, 1.9415e-01],
[-3.7319e+00, -1.0287e+00, -2.8267e+00, 1.6140e+00]],
[[-2.6132e+00, 1.2601e+00, 2.4735e+00, -5.1219e-01],
[-3.9365e+00, 1.1015e+00, 5.8874e-01, 3.0009e-01]]])
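The no-broadcast rule can be verified in the same way: under torch.bmm the two batch sizes must match exactly, while torch.matmul would broadcast a batch of 1. A minimal sketch (shapes are illustrative):

```python
import torch

a = torch.randn(6, 2, 3)
b = torch.randn(1, 3, 4)  # batch_size 1: matmul would broadcast this

try:
    torch.bmm(a, b)       # bmm requires both batch sizes to be 6
except RuntimeError as err:
    print("torch.bmm does not broadcast:", err)

print(torch.matmul(a, b).size())  # torch.Size([6, 2, 4])
```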
torch.matmul
Usage
torch.matmul(input1, input2, out=None)  # the most general form: handles batching, broadcasting, etc.
Behavior
1. The most versatile matrix multiplication; handles both batching and broadcasting.
2. If input1 is 1-D and input2 is 2-D, a dimension is prepended to input1 (equivalent to input1.unsqueeze(0)) for the vector-matrix product, and removed from the result afterwards.
3. With batched inputs, the batch dimension is preserved through the computation.
4. When the inputs' dimensions differ, they are first broadcast, then the batched multiplication is performed.
Examples
1. vector x vector
a = torch.randn(3)
b = torch.randn(3)
c = torch.matmul(a, b)
print(c, c.size())
#
tensor(1.2123) torch.Size([])
2. matrix x vector
a = torch.randn(3, 4)
b = torch.randn(4)
c = torch.matmul(a, b)
print(c.size())
#
torch.Size([3])
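Rule 2 (1-D input1, 2-D input2) is the mirror case of the example above; the prepended dimension is removed from the result, so the output is 1-D:

```python
import torch

a = torch.randn(4)      # 1-D vector
b = torch.randn(4, 3)   # 2-D matrix
c = torch.matmul(a, b)  # computed as (1, 4) x (4, 3), then the leading 1 is dropped
print(c.size())
# torch.Size([3])
```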
3. 3-D matrix x 2-D matrix / vector (broadcast)
a = torch.randn(10, 3, 4)
b = torch.randn(4, 6)
c = torch.randn(4)
d = torch.matmul(a, b)
e = torch.matmul(a, c)
print(d.size(), e.size())
#
torch.Size([10, 3, 6]) torch.Size([10, 3])
4. 3-D matrix x 3-D matrix
Here torch.matmul is equivalent to torch.bmm.
a = torch.randn(2, 3, 4)
b = torch.randn(2, 4, 6)
c = torch.matmul(a, b)
d = torch.bmm(a, b)
print(c == d)
print(c.size())
#
tensor([[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]],
[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
torch.Size([2, 3, 6])
Summary
Element-wise multiplication: torch.mul
2-D matrix multiplication: torch.mm
Batched 3-D matrix multiplication: torch.bmm
Batched / broadcasting matrix multiplication: torch.matmul
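The summary can be condensed into one runnable sketch; the operator forms `*` and `@` in the comments are the standard PyTorch equivalents, and the tensor shapes are illustrative:

```python
import torch

a = torch.randn(2, 3)
b = torch.randn(2, 3)
m = torch.randn(3, 4)
x = torch.randn(5, 2, 3)
y = torch.randn(5, 3, 4)

assert torch.equal(torch.mul(a, b), a * b)     # element-wise multiply
assert torch.equal(torch.mm(a, m), a @ m)      # 2-D matrix multiply
assert torch.equal(torch.bmm(x, y), x @ y)     # batched 3-D multiply
assert torch.equal(torch.matmul(x, m), x @ m)  # broadcasting matmul
print("all four equivalences hold")
```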