【小白學PyTorch】10 pytorch常見運算詳解

參考目錄:

這一課主要是講解PyTorch中的一些運算,加減乘除這些,當然還有矩陣的乘法這些。這一課內容不多,作爲一個知識儲備。在後續的內容中,有用PyTorch來獲取EfficientNet預訓練模型以及一個貓狗給分類的實戰任務教學。

加減乘除就不多說了,+-*/

1 矩陣與標量

這個是矩陣(張量)每一個元素與標量進行操作。

import torch
a = torch.tensor([1,2])
print(a+1)
>>> tensor([2, 3])

2 哈達瑪積

這個就是兩個相同尺寸的張量相乘,然後對應元素的相乘就是這個哈達瑪積,也成爲element wise。

a = torch.tensor([1,2])
b = torch.tensor([2,3])
print(a*b)
print(torch.mul(a,b))
>>> tensor([2, 6])
>>> tensor([2, 6])

這個torch.mul()*是等價的。

當然,除法也是類似的:

a = torch.tensor([1.,2.])
b = torch.tensor([2.,3.])
print(a/b)
print(torch.div(a/b))
>>> tensor([0.5000, 0.6667])
>>> tensor([0.5000, 0.6667])

我們可以發現的torch.div()其實就是/, 類似的:torch.add就是+,torch.sub()就是-,不過符號的運算更簡單常用。

3 矩陣乘法

如果我們想實現線性代數中的矩陣相乘怎麼辦呢?

這樣的操作有三個寫法:

  • torch.mm()
  • torch.matmul()
  • @,這個需要記憶,不然遇到這個可能會挺矇蔽的
a = torch.tensor([1.,2.])
b = torch.tensor([2.,3.]).view(1,2)
print(torch.mm(a, b))
print(torch.matmul(a, b))
print(a @ b)

輸出結果:

tensor([[2., 3.],
        [4., 6.]])
tensor([[2., 3.],
        [4., 6.]])
tensor([[2., 3.],
        [4., 6.]])

這是對二維矩陣而言的,假如參與運算的是一個多維張量,那麼只有torch.matmul()可以使用。等等,多維張量怎麼進行矩陣的懲罰?在多維張量中,參與矩陣運算的其實只有後兩個維度,前面的維度其實就像是索引一樣,舉個例子:

a = torch.rand((1,2,64,32))
b = torch.rand((1,2,32,64))
print(torch.matmul(a, b).shape)
>>> torch.Size([1, 2, 64, 64])

可以看到,其實矩陣乘法的時候,看後兩個維度:\(64 \times 32\) 乘上 \(32 \times 64\),得到一個\(64 \times 64\)的矩陣。前面的維度要求相同,像是索引一樣,決定哪兩個\(64 \times 32\)\(32 \times 64\)相乘。

小提示:

a = torch.rand((3,2,64,32))
b = torch.rand((1,2,32,64))
print(torch.matmul(a, b).shape)
>>> torch.Size([3, 2, 64, 64])

這樣也是可以相乘的,因爲這裏涉及一個自動傳播Broadcasting機制,這個在後面會講,這裏就知道,如果這種情況下,會把b的第一維度複製3次 ,然後變成和a一樣的尺寸,進行矩陣相乘。

4 冪與開方

print('冪運算')
a = torch.tensor([1.,2.])
b = torch.tensor([2.,3.])
c1 = a ** b
c2 = torch.pow(a, b)
print(c1,c2)
>>> tensor([1., 8.]) tensor([1., 8.])

和上面一樣,不多說了。
開方運算可以用torch.sqrt(),當然也可以用a**(0.5)。

5 對數運算

在上學的時候,我們知道ln是以e爲底的,但是在pytorch中,並不是這樣

pytorch中log是以e自然數爲底數的,然後log2和log10纔是以2和10爲底數的運算。

import numpy as np
print('對數運算')
a = torch.tensor([2,10,np.e])
print(torch.log(a))
print(torch.log2(a))
print(torch.log10(a))
>>> tensor([0.6931, 2.3026, 1.0000])
>>> tensor([1.0000, 3.3219, 1.4427])
>>> tensor([0.3010, 1.0000, 0.4343]) 

6 近似值運算

  • .ceil() 向上取整
  • .floor()向下取整
  • .trunc()取整數
  • .frac()取小數
  • .round()四捨五入
a = torch.tensor(1.2345)
print(a.ceil())
>>>tensor(2.)
print(a.floor())
>>> tensor(1.)
print(a.trunc())
>>> tensor(1.)
print(a.frac())
>>> tensor(0.2345)
print(a.round())
>>> tensor(1.)

7 剪裁運算

這個是讓一個數,限制在你自己設置的一個範圍內[min,max],小於min的話就被設置爲min,大於max的話就被設置爲max。這個操作在一些對抗生成網絡中,好像是WGAN-GP,通過強行限制模型的參數的值。

a = torch.rand(5)
print(a)
print(a.clamp(0.3,0.7))

輸出爲:

tensor([0.5271, 0.6924, 0.9919, 0.0095, 0.0340])
tensor([0.5271, 0.6924, 0.7000, 0.3000, 0.3000])
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章