pytorch函數之torch中的幾種乘法 #點乘torch.mm() #矩陣乘torch.mul(),torch.matmul() #高維Tensor相乘維度要求

1. 點乘——torch.mul(a, b)

點乘都是broadcast的,可以用torch.mul(a, b)實現,也可以直接用*實現。

當a, b維度不一致時,會自動填充到相同維度相點乘。

import torch

a = torch.ones(3,4)
print(a)
b = torch.Tensor([1,2,3]).reshape((3,1))
print(b)

print(torch.mul(a, b))

在這裏插入圖片描述

2. 矩陣乘

矩陣相乘有torch.mm(a, b)torch.matmul(a, b)兩個函數。

前一個是針對二維矩陣,後一個是高維。當torch.mm(a, b)用於大於二維時將報錯。

2.1. 二維矩陣乘——torch.mm(a, b)

import torch

a = torch.ones(3,4)
print(a)
b = torch.ones(4,2)
print(b)

print(torch.mm(a, b))

在這裏插入圖片描述
torch.mm(a, b)用於大於二維時將報錯:

2.2. 高維矩陣乘——torch.matmul(a, b)

torch.matmul(a, b)可以用於二維:

import torch

a = torch.ones(3,4)
print(a)
b = torch.ones(4,2)
print(b)

print(torch.matmul(a, b))

torch.matmul(a, b)可以用於高維:

import torch

a = torch.ones(3,1,2)
print(a)
b = torch.ones(3,2,2)
print(b)

print(torch.matmul(a, b))

3. 高維的Tensor相乘維度要求

兩個Tensor維度要求:

  • "2維以上"的尺寸必須完全對應相等;
  • "2維"具有實際意義的單位,只要滿足矩陣相乘的尺寸規律即可。

3.1. 對於維數相同的張量

A.shape =(b,m,n);B.shape = (b,n,k)
numpy.matmul(A,B) 結果shape爲(b,m,k)

要求第一維度相同,後兩個維度能滿足矩陣相乘條件。

import torch

a = torch.ones(3,1,2)
print(a)
b = torch.ones(3,2,2)
print(b)

print(torch.matmul(a, b))

在這裏插入圖片描述

3.2. 對於維數不一樣的張量

比如 A.shape =(m,n); B.shape = (b,n,k); C.shape=(k,l)

numpy.matmul(A,B) 結果shape爲(b,m,k)

numpy.matmul(B,C) 結果shape爲(b,n,l)

2D張量要和3D張量的後兩個維度滿足矩陣相乘條件。

import torch

a = torch.ones(1,2)
print(a)
b = torch.ones(2,2,3)
print(b)
c = torch.ones(3,1)
print(b)

print(torch.matmul(a, b))
print(torch.matmul(b, c))

在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章