python中numpy的axis和torch的dim

舉一個例子:

import torch
A = torch.rand((3,4))
print(A)
#tensor([[0.3602, 0.2583, 0.1758, 0.3575],
#        [0.9582, 0.2092, 0.6829, 0.8663],
#        [0.3922, 0.1360, 0.3733, 0.3477]])
z = A.sum(dim=1, keepdim=True)
print(z)
#tensor([[1.1518],
#        [2.7166],
#        [1.2492]])
y = A.sum(dim=1)
print(y)
#tensor([1.1518, 2.7166, 1.2492])

我們看keepdim=True的情形(此時最清楚,沒有做置換強行改變行列下標),dim=1就相當於sum是在dim=1上做,於是dim=1的下標就沒有了,只剩下dim=0的下標。即,z的元素下標爲 [0,0],[1,0],[2,0],dim=1的下標全部爲0。

從這個角度看,不用畫圖,只考慮運算。另外,也可以認爲sum(dim=1),就是dim=1\sum\limits_{dim=1}

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章