pytorch獲取nn.Sequential的中間層輸出

對於nn.Sequential結構,要想獲取中間網絡層輸出,可以使用循環遍歷的方式得到。

示例

import torch
import torch.nn as nn
model = nn.Sequential(
            nn.Conv2d(3, 9, 1, 1, 0, bias=False),
            nn.BatchNorm2d(9),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
        )
# 假如想要獲得ReLu的輸出
x = torch.rand([2, 3, 224, 224])
for i in range(len(model)):
    x = model[i](x)
    if i == 2:
        ReLu_out = x
print('ReLu_out.shape:\n\t',ReLu_out.shape)
print('x.shape:\n\t',x.shape)

結果

ReLu_out.shape:
  torch.Size([2, 9, 224, 224])
x.shape:
  torch.Size([2, 9, 1, 1])
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章