PyTorch - torch.max、torch.min

PyTorch - torch.max、torch.min

flyfish

torch.max

找出輸入張量中所有元素的最大值
參數
input(張量)–輸入張量。
dim(int)–要減少的維。
keepdim(布爾 bool)–輸出張量是否保持dim。 默認值:False。
out(元組 tuple,可選)–兩個輸出張量的結果元組(max,max_indices)

import torch
# 一維的情況
a = torch.randn(1, 3)
print(a)#tensor([[ 0.4084,  0.0232, -0.7173]])
print(torch.max(a))#tensor(0.4084)

#二維的情況
a = torch.randn(4, 3)
print(a)

# tensor([[-0.9843, -0.1896, -0.9875],
#         [-0.2581, -0.2075,  1.0063],
#         [-0.8443, -0.6996,  0.0647],
#         [ 2.0349, -0.8708,  0.8848]])

print("0 False:",torch.max(a, dim=0,keepdim=False))#輸出結果是3個元素,shape是[3]
# 0 False: torch.return_types.max(
# values=tensor([ 2.0349, -0.1896,  1.0063]),
# indices=tensor([3, 0, 1]))

print("1 False:",torch.max(a, dim=1,keepdim=False))#輸出結果是4個元素shape是[4]
# 1 False: torch.return_types.max(
# values=tensor([-0.1896,  1.0063,  0.0647,  2.0349]),
# indices=tensor([1, 2, 2, 0]))

print("0 True:",torch.max(a, dim=0,keepdim=True))# shape是[1,3]
# 0 True: torch.return_types.max(
# values=tensor([[ 2.0349, -0.1896,  1.0063]]),
# indices=tensor([[3, 0, 1]]))

print("1 Ture:",torch.max(a, dim=1,keepdim=True))# shape是 [4,1]
# 1 Ture: torch.return_types.max(
# values=tensor([[-0.1896],
#         [ 1.0063],
#         [ 0.0647],
#         [ 2.0349]]),
# indices=tensor([[1],
#         [2],
#         [2],
#         [0]]))

另一中用法
a.max(1, keepdim=True) 與上面的torch.max(a, dim=1,keepdim=True) 是一樣的

維度從0開始
二維的時候 我們可以用 行和列 來表示
4行 3列
dim =0 相當於 減少0維,找出每列最大的,
dim =1 表示 減少1維 ,找出每行最大的
返回兩個類型,一個是值,一個是索引
如果遇到torch.min 變成找出輸入張量中所有元素的最小值

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章