Module.children() vs Module.modules()

In one sentence:

Module.modules(): performs a depth-first traversal. Use Module.modules() when you want to recursively iterate over all modules, including the module itself.

Module.children(): yields only the immediate child modules. Use Module.children() when you only want that first generation.
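To make the contrast concrete, here is a minimal sketch (the nested nn.Sequential model is my own toy example, not from the linked posts): children() sees only the two top-level entries, while modules() also descends into the inner Sequential and includes the model itself.

```python
import torch.nn as nn

# A tiny model: one nested Sequential plus a top-level Linear
net = nn.Sequential(
    nn.Sequential(nn.Linear(2, 2), nn.ReLU()),  # nested container
    nn.Linear(2, 1),
)

# children(): only the direct children (2 of them)
print(len(list(net.children())))  # 2

# modules(): net itself + inner Sequential + Linear + ReLU + Linear
print(len(list(net.modules())))   # 5
```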

A detailed explanation can be found here:

Link: https://discuss.pytorch.org/t/module-children-vs-module-modules/4551/3

A more intuitive explanation, from which part of the code and the network structure diagram below are taken, can be found here:

Link: https://blog.csdn.net/dss_dssssd/article/details/83958518

The code:

import torch
from torch import nn

# hyper parameters
in_dim = 1
n_hidden_1 = 1
n_hidden_2 = 1
out_dim = 1

class Net(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super().__init__()

        self.layer = nn.Sequential(
            nn.Linear(in_dim, n_hidden_1),
            nn.ReLU(True)
        )
        self.layer2 = nn.Sequential(
            nn.Linear(n_hidden_1, n_hidden_2),
            nn.ReLU(True),
        )
        self.layer3 = nn.Linear(n_hidden_2, out_dim)
        print("******result of the self.children()******")
        for i, module in enumerate(self.children()):
            print(i, module)

        print("******result of the self.modules()******")
        for i, module in enumerate(self.modules()):
            print(i, module)

    def forward(self, x):
        x = self.layer(x)  # the first block is named self.layer in __init__
        x = self.layer2(x)
        x = self.layer3(x)
        return x

all_layer = []

# collect the individual leaf components, discarding the Sequential containers
def remove_sequential(network):
    for layer in network.children():
        if isinstance(layer, nn.Sequential):
            remove_sequential(layer)
        if list(layer.children()) == []:  # leaf module: no children of its own
            all_layer.append(layer)

model = Net(in_dim, n_hidden_1, n_hidden_2, out_dim)

remove_sequential(model)
print('***************************')
print(all_layer)

The output:

******result of the self.children()******
0 Sequential(
  (0): Linear(in_features=1, out_features=1, bias=True)
  (1): ReLU(inplace)
)
1 Sequential(
  (0): Linear(in_features=1, out_features=1, bias=True)
  (1): ReLU(inplace)
)
2 Linear(in_features=1, out_features=1, bias=True)
******result of the self.modules()******
0 Net(
  (layer): Sequential(
    (0): Linear(in_features=1, out_features=1, bias=True)
    (1): ReLU(inplace)
  )
  (layer2): Sequential(
    (0): Linear(in_features=1, out_features=1, bias=True)
    (1): ReLU(inplace)
  )
  (layer3): Linear(in_features=1, out_features=1, bias=True)
)
1 Sequential(
  (0): Linear(in_features=1, out_features=1, bias=True)
  (1): ReLU(inplace)
)
2 Linear(in_features=1, out_features=1, bias=True)
3 ReLU(inplace)
4 Sequential(
  (0): Linear(in_features=1, out_features=1, bias=True)
  (1): ReLU(inplace)
)
5 Linear(in_features=1, out_features=1, bias=True)
6 ReLU(inplace)
7 Linear(in_features=1, out_features=1, bias=True)
***************************
[Linear(in_features=1, out_features=1, bias=True), ReLU(inplace), Linear(in_features=1, out_features=1, bias=True), ReLU(inplace), Linear(in_features=1, out_features=1, bias=True)]
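The flattening that remove_sequential performs can also be written directly with modules(), since the leaf modules are exactly those with no children of their own. A minimal sketch (the toy model here is my own, not the Net class above):

```python
import torch.nn as nn

model = nn.Sequential(
    nn.Sequential(nn.Linear(1, 1), nn.ReLU()),
    nn.Linear(1, 1),
)

# modules() walks everything depth-first; keep only leaves
leaves = [m for m in model.modules() if len(list(m.children())) == 0]
print(leaves)  # three leaf modules: Linear, ReLU, Linear
```

This avoids the explicit recursion and the module-level all_layer list, at the cost of walking every container node as well.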

The network structure diagram for the code (see the linked blog post):

Note: While recently working on a PyTorch project, I ran into self.children() and self.modules() and didn't understand them. Thanks to the links above for the clear explanations. This post is just a record so that I can review it from time to time.
