pytorch 中涉及到矩阵之间的乘法(torch.mul, *, torch.mm, torch.matmul, @)

最近在学习pytorch,过程中遇到一些问题,这里权当笔记记录下来,同时也供大家参考。

下面简单回顾一下矩阵中的乘法:(严谨的说,其实应该说是矩阵乘法和矩阵内积)
1、矩阵乘法
  矩阵乘法也就是我们常说的矩阵向量积(也称矩阵外积矩阵叉乘
      它要求前一个矩阵的行数等于后一个矩阵的列数,其计算方法是计算结果的每一行元素为前一个矩阵的每一行元素与后一个矩阵的每一列对应元素相乘,之后求和。下面232*3矩阵与353*5矩阵为例:
[111111]×[123451234512345]=[36912153691215] \begin{gathered} \begin{bmatrix} 1 & 1 & 1\\ 1 & 1 & 1 \end{bmatrix} \times \begin{bmatrix} 1 & 2 & 3 & 4 & 5 \\ 1 & 2 & 3 & 4 & 5 \\ 1 & 2 & 3 & 4 & 5 \end{bmatrix} \end{gathered}=\begin{bmatrix} 3 & 6 & 9 & 12 & 15 \\ 3 & 6 & 9 & 12 & 15 \end{bmatrix}
其计算方法为:
11+11+11=a11=3,   12+12+12=a12=61*1+1*1+1*1=a11=3, \, \,\,1*2+1*2+1*2=a12=6……
其中a11为第一行第一个元素,以此类推
2、矩阵内积
  矩阵点法也就是我们常说的矩阵点乘
       即矩阵的对应元素相乘,故它要求两个矩阵形状一样,下面232*3矩阵与232*3矩阵为例:
[111111].[123456]=[123456] \begin{gathered} \begin{bmatrix} 1 & 1 & 1\\ 1 & 1 & 1 \end{bmatrix} . \begin{bmatrix}1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} \end{gathered}=\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix}

  在进入正题之前,先扯点儿闲篇——大家应该都知道numpy(至少听说过,python的一个数值计算库,pytorch不火的时候,numpy还是很好用的),而pytorch,主要特点是可以使用GPU加速运算,但是计算上和numpy有很多类似之处,那好,介绍pytorch的矩阵乘法之前,先说说numpy中ndarray中矩阵的乘法:
  numpyt中点乘使用*或者np.multiply(),而叉乘使用@, np.dot(), np.matmul()
测试测序如下:

import numpy as np

print("numpy")
A = np.array([[1, 2, 3, 6], [2, 3, 4, 3], [2, 3, 4, 4]])
B = np.array([[1, 0, 1, 4], [2, 1, -1, 0], [2, 1, 5, 0]])
C = np.array([[1, 0, 3], [0, 1, 2], [-1, 0, 1], [-1, 0, 1]])

# 对应位置相乘,点乘
print("矩阵对应元素相乘 点乘")
print("*运算符\n", A*B)
print("np.multiply\n", np.multiply(A, B))

print("矩阵相乘 叉乘")
print("A.dot\n", A.dot(C))  # 矩阵乘法
print("@运算符\n", A@C)
print("np.matmul\n", np.matmul(A, C), '\n')

有人会问这里的dot和matmul函数有什么区别

请移步numpy中dot和matmul的区别

  而pytorch中用法略有不同,其中点乘使用*或者np.mul(),而叉乘使用@, torch.mm(), torch.matmul()(注意这里没有dot函数,使用torch.mm函数)

import torch
print("pytorch")
a = torch.ones(2, 3)
c = torch.FloatTensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])
#b = torch.randint(1, 9, (2, 3))
b = torch.FloatTensor([[1, 2, 3], [4, 5, 6]])

print("矩阵对应元素相乘 点乘")
print("*运算符\n", a * b)
print("torch.mul\n", torch.mul(a, b))

print("矩阵相乘 叉乘")
print("@运算符\n", a@c)
print("torch.mm\n", torch.mm(a, c))
print("torch.matmul\n", torch.matmul(a, c), "\n")

输出结果:
在这里插入图片描述
下面又有一个问题,torch.mm()和torch.matmul()到底有什么区别?
可以参考官网教程
如果你懒得看,你可以看下面这两张我从官网上截的图
torch.mm函数用法
在这里插入图片描述
当然了,如果还是难以理解的话,请移步这里

参考:
[1]https://blog.csdn.net/She_Said/article/details/98034841
[2]https://www.jb51.net/article/177406.htm
[3]https://pytorch.org/docs/stable/torch.html

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章