Pytorch torch.mean()的簡單用法
簡單來說就是求平均數。
比如以下的三種簡單情況:
import torch x1 = torch.Tensor([1, 2, 3, 4]) x2 = torch.Tensor([[1], [2], [3], [4]]) x3 = torch.Tensor([[1, 2], [3, 4]]) y1 = torch.mean(x1) y2 = torch.mean(x2) y3 = torch.mean(x3) print(y1) print(y2) print(y3)
輸出:
tensor(2.5000)
tensor(2.5000)
tensor(2.5000)
也就是說,在沒有指定維度的情況下,就是對所有數進行求平均。
更多的時候用到的是有維度的情形,如:
二維張量求均值:
import torch x = torch.Tensor([1, 2, 3, 4, 5, 6]).view(2, 3) y_0 = torch.mean(x, dim=0) ## 每列求均值 y_1 = torch.mean(x, dim=1) ### 每行求均值 print(x) print(y_0) print(y_1)
輸出:
tensor([[1., 2., 3.],
[4., 5., 6.]])
tensor([2.5000, 3.5000, 4.5000])
tensor([2., 5.])
輸入tensor的形狀爲(2, 3),其中2爲第0維,3爲第1維。對第0維求平均,得到的結果爲形狀爲(1, 3)的tensor;對第1維求平均,得到的結果爲形狀爲(2, 1)的tensor。
可以理解爲,對哪一維做平均,就是將該維所有的數做平均,壓扁成1層(實際上這一層就給合併掉了,比如上面的例子,2維的tensor在求平均數後變成了1維),而其他維的形狀不影響。
如果要保持維度不變(例如在深度網絡中),則可以加上參數keepdim=True:
y = torch.mean(x, dim=1, keepdim=True)
三維張量求均值:
import torch import numpy as np # ======初始化一個三維矩陣===== A = torch.ones((4,3,2)) # ======替換三維矩陣裏面的值====== A[0] = torch.ones((3,2)) *1 A[1] = torch.ones((3,2)) *2 A[2] = torch.ones((3,2)) *3 A[3] = torch.ones((3,2)) *4 print(A) B = torch.mean(A ,dim=0) print(B) B = torch.mean(A ,dim=1) print(B) B = torch.mean(A ,dim=2) print(B)
輸出結果
tensor([[[1., 1.], [1., 1.], [1., 1.]], [[2., 2.], [2., 2.], [2., 2.]], [[3., 3.], [3., 3.], [3., 3.]], [[4., 4.], [4., 4.], [4., 4.]]]) tensor([[2.5000, 2.5000], [2.5000, 2.5000], [2.5000, 2.5000]]) tensor([[1., 1.], [2., 2.], [3., 3.], [4., 4.]]) tensor([[1., 1., 1.], [2., 2., 2.], [3., 3., 3.], [4., 4., 4.]])
REF
https://blog.csdn.net/qq_40714949/article/details/115485140