PyTorch - torch.max、torch.min
flyfish
torch.max
找出輸入張量中所有元素的最大值
參數
input(張量)–輸入張量。
dim(int)–要減少的維。
keepdim(布爾 bool)–輸出張量是否保持dim。 默認值:False。
out(元組 tuple,可選)–兩個輸出張量的結果元組(max,max_indices)
import torch
# 一維的情況
a = torch.randn(1, 3)
print(a)#tensor([[ 0.4084, 0.0232, -0.7173]])
print(torch.max(a))#tensor(0.4084)
#二維的情況
a = torch.randn(4, 3)
print(a)
# tensor([[-0.9843, -0.1896, -0.9875],
# [-0.2581, -0.2075, 1.0063],
# [-0.8443, -0.6996, 0.0647],
# [ 2.0349, -0.8708, 0.8848]])
print("0 False:",torch.max(a, dim=0,keepdim=False))#輸出結果是3個元素,shape是[3]
# 0 False: torch.return_types.max(
# values=tensor([ 2.0349, -0.1896, 1.0063]),
# indices=tensor([3, 0, 1]))
print("1 False:",torch.max(a, dim=1,keepdim=False))#輸出結果是4個元素shape是[4]
# 1 False: torch.return_types.max(
# values=tensor([-0.1896, 1.0063, 0.0647, 2.0349]),
# indices=tensor([1, 2, 2, 0]))
print("0 True:",torch.max(a, dim=0,keepdim=True))# shape是[1,3]
# 0 True: torch.return_types.max(
# values=tensor([[ 2.0349, -0.1896, 1.0063]]),
# indices=tensor([[3, 0, 1]]))
print("1 Ture:",torch.max(a, dim=1,keepdim=True))# shape是 [4,1]
# 1 Ture: torch.return_types.max(
# values=tensor([[-0.1896],
# [ 1.0063],
# [ 0.0647],
# [ 2.0349]]),
# indices=tensor([[1],
# [2],
# [2],
# [0]]))
另一中用法
a.max(1, keepdim=True) 與上面的torch.max(a, dim=1,keepdim=True) 是一樣的
維度從0開始
二維的時候 我們可以用 行和列 來表示
4行 3列
dim =0 相當於 減少0維,找出每列最大的,
dim =1 表示 減少1維 ,找出每行最大的
返回兩個類型,一個是值,一個是索引
如果遇到torch.min 變成找出輸入張量中所有元素的最小值