舉一個例子:
import torch
A = torch.rand((3,4))
print(A)
#tensor([[0.3602, 0.2583, 0.1758, 0.3575],
# [0.9582, 0.2092, 0.6829, 0.8663],
# [0.3922, 0.1360, 0.3733, 0.3477]])
z = A.sum(dim=1, keepdim=True)
print(z)
#tensor([[1.1518],
# [2.7166],
# [1.2492]])
y = A.sum(dim=1)
print(y)
#tensor([1.1518, 2.7166, 1.2492])
我們看keepdim=True的情形(此時最清楚,沒有做置換強行改變行列下標),dim=1就相當於sum是在dim=1上做,於是dim=1的下標就沒有了,只剩下dim=0的下標。即,z的元素下標爲 [0,0],[1,0],[2,0],dim=1的下標全部爲0。
從這個角度看,不用畫圖,只考慮運算。另外,也可以認爲sum(dim=1),就是