import torch.nn as nn
import torch
# 創建三維tensor
a = torch.randn(3,4,5)
print(a.shape)
print(a)
# 升維,升成四維
a = torch.unsqueeze(a, 0)
print(a.shape)
print(a)
# AdaptiveAvgPool2d(X) 是將W H 使用平均池化降爲X維
avg = nn.AdaptiveAvgPool2d(1)
b = avg(a)
print(a.shape)
print(b)
print(b.shape)