torch.mul | torch.mm | torch.bmm | torch.matmul的區別和使用

torch.mul

用法

torch.mul(input1, input2, out=None) #對位元素相乘

功能
1、當 input1是矩陣/向量 和 input2 是矩陣/向量,則對位相乘
2、當 input1 是矩陣/向量,input2 是標量,則 input1 的所有元素乘上input2
3、當 input1 是矩陣/向量,input2 是向量/矩陣,則 input2 / input1 先進行廣播,再對位相乘

舉例
#1 input1 和input2 的size相同,對位相乘

a = torch.tensor([[ 1.8351,  2.1536],
   		          [-0.8320, -1.4578]])
b = torch.tensor([[2.9355, 0.3450],
    	          [2.9355, 0.3450]])
c = torch.mul(a,b)
tensor([[ 5.3869,  0.7429],
        [-2.4423, -0.5029]])

#3 input1 是矩陣,input2 是向量,則input2 先進行廣播,再對位相乘

a = torch.tensor([[ 1.8351,  2.1536],
   		          [-0.8320, -1.4578]])
b = torch.tensor([2.9355, 0.3450])
c = torch.mul(a,b)
tensor([[ 5.3869,  0.7429],
        [-2.4423, -0.5029]])

可見,#1 與 #3 結果相同。

torch.mm

用法

torch.mm(input1, input2, out=None) #二維矩陣相乘

功能
處理二維矩陣的乘法 (a, b) × (b, d) = (a, d),而且也只能處理二維矩陣,其他維度要用torch.matmul
舉例

mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)
mat3 = torch.mm(mat1, mat2)
tensor([[ 0.4851,  0.5037, -0.3633],
        [-0.0760, -3.6705,  2.4784]])

torch.bmm

用法

torch.bmm(mat1, mat2, out=None)#三維矩陣,第一維是batch_size,後兩維進行矩陣的乘法

功能
看函數名就知道,在torch.mm的基礎上加了個batch計算,不能廣播
舉例

mat1 = torch.randn(6, 2, 3)#batch_size = 6
mat2 = torch.randn(6, 3, 4)
mat3 = torch.bmm(mat1, mat2)#torch.Size([6, 2, 4])
tensor([[[-5.5404e-02, -1.2719e+00, -1.3952e+00,  7.2475e-01],
         [ 1.0943e+00,  2.1826e+00, -4.4239e-01, -1.0643e+00]],

        [[ 1.1785e+00, -4.9125e-01, -3.4894e-01, -2.1170e-02],
         [-6.4008e-01, -2.4427e-03, -3.1276e-01, -4.5647e-01]],

        [[-2.9938e-01,  7.6840e-01, -2.7852e-01,  5.4946e-01],
         [ 4.2854e-01,  1.8301e+00,  1.7477e-02, -1.4107e+00]],

        [[-2.7399e-01,  1.2810e+00,  1.8456e+00, -5.5862e-01],
         [ 1.0337e+00,  1.3213e+00,  7.3194e-01,  3.9463e-01]],

        [[-1.3685e-01, -9.7863e-02, -3.3586e-01,  1.9415e-01],
         [-3.7319e+00, -1.0287e+00, -2.8267e+00,  1.6140e+00]],

        [[-2.6132e+00,  1.2601e+00,  2.4735e+00, -5.1219e-01],
         [-3.9365e+00,  1.1015e+00,  5.8874e-01,  3.0009e-01]]])

torch.matmul

用法

torch.matmul(input1, input2, out=None)#適用性最多,能處理batch、廣播等

功能
1、適用性最多的,能處理batch、廣播的矩陣乘法
2、input1 是一維,input2 是二維,那麼給input1 提供一個維度(相當於 input1.unsqueeze(0)),再進行向量乘矩陣
3、帶有batch的情況,可保留batch計算
4、維度不同時,可先廣播,再batch計算
舉例
1、vector x vector

a = torch.randn(3)
b = torch.randn(3)
c = torch.matmul(a, b)
print(c, c.size())
#
tensor(1.2123) torch.Size([])

2、matrix x vector

a = torch.randn(3, 4)
b = torch.randn(4)
c = torch.matmul(a, b)
print(c.size())
#
torch.Size([3])

3、3dmatrix x 2dmatrix / vector(broadcasted)

a = torch.randn(10, 3, 4)
b = torch.randn(4, 6)
c = torch.randn(4)

d = torch.matmul(a, b)
e = torch.matmul(a, c)
print(c.size(), e.size())
#
torch.Size([10, 3, 6]) torch.Size([10, 3])

4、3dmatrix x 3dmatrix
此時與torch.bmm等效

a = torch.randn(2, 3, 4)
b = torch.randn(2, 4, 6)
c = torch.matmul(a, b)
d = torch.bmm(a, b)
print(c == d)
print(c.size())
#
tensor([[[1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1]],

        [[1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
torch.Size([2, 3, 6])

總結

對位相乘torch.mul

二維矩陣乘法torch.mm

batch三維矩陣乘法torch.bmm

batch、廣播矩陣乘法torch.matmul

參考Blog

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章