文章目錄
1. 點乘——torch.mul(a, b)
點乘都是broadcast的,可以用torch.mul(a, b)
實現,也可以直接用*
實現。
當a, b維度不一致時,會自動填充到相同維度相點乘。
import torch
a = torch.ones(3,4)
print(a)
b = torch.Tensor([1,2,3]).reshape((3,1))
print(b)
print(torch.mul(a, b))
2. 矩陣乘
矩陣相乘有torch.mm(a, b)
和torch.matmul(a, b)
兩個函數。
前一個是針對二維矩陣,後一個是高維。當torch.mm(a, b)
用於大於二維時將報錯。
2.1. 二維矩陣乘——torch.mm(a, b)
import torch
a = torch.ones(3,4)
print(a)
b = torch.ones(4,2)
print(b)
print(torch.mm(a, b))
當torch.mm(a, b)
用於大於二維時將報錯:
2.2. 高維矩陣乘——torch.matmul(a, b)
torch.matmul(a, b)
可以用於二維:
import torch
a = torch.ones(3,4)
print(a)
b = torch.ones(4,2)
print(b)
print(torch.matmul(a, b))
torch.matmul(a, b)
可以用於高維:
import torch
a = torch.ones(3,1,2)
print(a)
b = torch.ones(3,2,2)
print(b)
print(torch.matmul(a, b))
3. 高維的Tensor相乘維度要求
兩個Tensor維度要求:
- "2維以上"的尺寸必須完全對應相等;
- "2維"具有實際意義的單位,只要滿足矩陣相乘的尺寸規律即可。
3.1. 對於維數相同的張量
A.shape =(b,m,n);B.shape = (b,n,k)
numpy.matmul(A,B) 結果shape爲(b,m,k)
要求第一維度相同,後兩個維度能滿足矩陣相乘條件。
import torch
a = torch.ones(3,1,2)
print(a)
b = torch.ones(3,2,2)
print(b)
print(torch.matmul(a, b))
3.2. 對於維數不一樣的張量
比如 A.shape =(m,n); B.shape = (b,n,k); C.shape=(k,l)
numpy.matmul(A,B) 結果shape爲(b,m,k)
numpy.matmul(B,C) 結果shape爲(b,n,l)
2D張量要和3D張量的後兩個維度滿足矩陣相乘條件。
import torch
a = torch.ones(1,2)
print(a)
b = torch.ones(2,2,3)
print(b)
c = torch.ones(3,1)
print(b)
print(torch.matmul(a, b))
print(torch.matmul(b, c))