pytorch函数之torch中的几种乘法 #点乘torch.mm() #矩阵乘torch.mul(),torch.matmul() #高维Tensor相乘维度要求

1. 点乘——torch.mul(a, b)

点乘都是broadcast的,可以用torch.mul(a, b)实现,也可以直接用*实现。

当a, b维度不一致时,会自动填充到相同维度相点乘。

import torch

a = torch.ones(3,4)
print(a)
b = torch.Tensor([1,2,3]).reshape((3,1))
print(b)

print(torch.mul(a, b))

在这里插入图片描述

2. 矩阵乘

矩阵相乘有torch.mm(a, b)torch.matmul(a, b)两个函数。

前一个是针对二维矩阵,后一个是高维。当torch.mm(a, b)用于大于二维时将报错。

2.1. 二维矩阵乘——torch.mm(a, b)

import torch

a = torch.ones(3,4)
print(a)
b = torch.ones(4,2)
print(b)

print(torch.mm(a, b))

在这里插入图片描述
torch.mm(a, b)用于大于二维时将报错:

2.2. 高维矩阵乘——torch.matmul(a, b)

torch.matmul(a, b)可以用于二维:

import torch

a = torch.ones(3,4)
print(a)
b = torch.ones(4,2)
print(b)

print(torch.matmul(a, b))

torch.matmul(a, b)可以用于高维:

import torch

a = torch.ones(3,1,2)
print(a)
b = torch.ones(3,2,2)
print(b)

print(torch.matmul(a, b))

3. 高维的Tensor相乘维度要求

两个Tensor维度要求:

  • "2维以上"的尺寸必须完全对应相等;
  • "2维"具有实际意义的单位,只要满足矩阵相乘的尺寸规律即可。

3.1. 对于维数相同的张量

A.shape =(b,m,n);B.shape = (b,n,k)
numpy.matmul(A,B) 结果shape为(b,m,k)

要求第一维度相同,后两个维度能满足矩阵相乘条件。

import torch

a = torch.ones(3,1,2)
print(a)
b = torch.ones(3,2,2)
print(b)

print(torch.matmul(a, b))

在这里插入图片描述

3.2. 对于维数不一样的张量

比如 A.shape =(m,n); B.shape = (b,n,k); C.shape=(k,l)

numpy.matmul(A,B) 结果shape为(b,m,k)

numpy.matmul(B,C) 结果shape为(b,n,l)

2D张量要和3D张量的后两个维度满足矩阵相乘条件。

import torch

a = torch.ones(1,2)
print(a)
b = torch.ones(2,2,3)
print(b)
c = torch.ones(3,1)
print(b)

print(torch.matmul(a, b))
print(torch.matmul(b, c))

在这里插入图片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章