AdaptiveAvgPool2d 测试

import torch.nn as nn
import torch


# 创建三维tensor
a = torch.randn(3,4,5)
print(a.shape)
print(a)
# 升维,升成四维
a = torch.unsqueeze(a, 0)
print(a.shape)
print(a)
# AdaptiveAvgPool2d(X) 是将W H 使用平均池化降为X维
avg = nn.AdaptiveAvgPool2d(1)
b = avg(a)
print(a.shape)
print(b)
print(b.shape)

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章