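In a nutshell:
Module.modules(): traverses the module tree depth-first. Use Module.modules() to recursively visit every module, including the module itself.
Module.children(): use Module.children() to iterate over only the immediate child modules (one level down).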
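The difference can be sketched without PyTorch. The class ToyModule below is hypothetical and is not the real nn.Module. It only mimics the traversal order:
- children() yields the direct submodules.
- modules() does a pre-order depth-first walk that starts with the module itself.

Real PyTorch also skips a module instance that is registered more than once; this sketch leaves that detail out. The toy tree has the same shape as the Net in the example below.

```python
class ToyModule:
    """Hypothetical stand-in for nn.Module: a name plus ordered submodules."""

    def __init__(self, name, *subs):
        self.name = name
        self._subs = list(subs)  # registration order, like nn.Module._modules

    def children(self):
        # Only the direct submodules (one level down), in registration order.
        yield from self._subs

    def modules(self):
        # Pre-order depth-first traversal: yield this module first, then
        # each child's entire subtree before moving on to the next child.
        yield self
        for child in self._subs:
            yield from child.modules()


# Same shape as the Net used in the example below.
net = ToyModule(
    "Net",
    ToyModule("Sequential", ToyModule("Linear"), ToyModule("ReLU")),
    ToyModule("Sequential", ToyModule("Linear"), ToyModule("ReLU")),
    ToyModule("Linear"),
)

print([m.name for m in net.children()])
# ['Sequential', 'Sequential', 'Linear']
print([m.name for m in net.modules()])
# ['Net', 'Sequential', 'Linear', 'ReLU', 'Sequential', 'Linear', 'ReLU', 'Linear']
```

children() stops after three entries. modules() yields eight, because it descends into each Sequential before moving on.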
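A detailed explanation can be found here:
Link: https://discuss.pytorch.org/t/module-children-vs-module-modules/4551/3
A more intuitive explanation is available at the following link, which is also the source of part of the code and the network structure diagram below:
Link: https://blog.csdn.net/dss_dssssd/article/details/83958518
The code is as follows: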
import torch
from torch import nn

# hyper parameters
in_dim = 1
n_hidden_1 = 1
n_hidden_2 = 1
out_dim = 1


class Net(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Linear(in_dim, n_hidden_1),
            nn.ReLU(True)
        )
        self.layer2 = nn.Sequential(
            nn.Linear(n_hidden_1, n_hidden_2),
            nn.ReLU(True),
        )
        self.layer3 = nn.Linear(n_hidden_2, out_dim)

        # children(): only the direct submodules of Net
        print("******result of the self.children()******")
        for i, module in enumerate(self.children()):
            print(i, module)

        # modules(): every module in the tree, depth-first, starting with Net itself
        print("******result of the self.modules()******")
        for i, module in enumerate(self.modules()):
            print(i, module)

    def forward(self, x):
        x = self.layer(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x


all_layer = []


# Collect the individual leaf layers, flattening away the nn.Sequential containers
def remove_sequential(module):
    for layer in module.children():
        if isinstance(layer, nn.Sequential):
            remove_sequential(layer)
        if not list(layer.children()):  # no children: a leaf layer
            all_layer.append(layer)


model = Net(in_dim, n_hidden_1, n_hidden_2, out_dim)
remove_sequential(model)
print('***************************')
print(all_layer)
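remove_sequential only descends into nn.Sequential. Leaves inside any other container, such as an nn.ModuleList or a custom sub-Module, are therefore skipped without warning. In real PyTorch, a more general alternative is to keep every module from model.modules() whose list(m.children()) is empty.

The torch-free sketch below compares the two strategies on a toy tree. The helpers node, children, modules and leaf_modules, and the "ModuleList" node, are hypothetical stand-ins and not PyTorch objects.

```python
# Toy tree: each node is (kind, [children]); a hypothetical stand-in for
# nn.Module objects so the sketch runs without PyTorch.
def node(kind, *children):
    return (kind, list(children))


def children(m):
    return m[1]


def modules(m):
    # Pre-order DFS including m itself, mirroring Module.modules().
    yield m
    for c in children(m):
        yield from modules(c)


def remove_sequential(m, out):
    # Same logic as the original: only descends into "Sequential" nodes.
    for layer in children(m):
        if layer[0] == "Sequential":
            remove_sequential(layer, out)
        if not children(layer):
            out.append(layer)
    return out


def leaf_modules(m):
    # General version: every module that has no children of its own.
    return [x for x in modules(m) if not children(x)]


net = node(
    "Net",
    node("Sequential", node("Linear"), node("ReLU")),
    node("ModuleList", node("Linear"), node("ReLU")),  # not a Sequential
    node("Linear"),
)

print([k for k, _ in remove_sequential(net, [])])
# ['Linear', 'ReLU', 'Linear']  (the ModuleList's layers are missed)
print([k for k, _ in leaf_modules(net)])
# ['Linear', 'ReLU', 'Linear', 'ReLU', 'Linear']
```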
Output of the Net example:
******result of the self.children()******
0 Sequential(
(0): Linear(in_features=1, out_features=1, bias=True)
(1): ReLU(inplace)
)
1 Sequential(
(0): Linear(in_features=1, out_features=1, bias=True)
(1): ReLU(inplace)
)
2 Linear(in_features=1, out_features=1, bias=True)
******result of the self.modules()******
0 Net(
(layer): Sequential(
(0): Linear(in_features=1, out_features=1, bias=True)
(1): ReLU(inplace)
)
(layer2): Sequential(
(0): Linear(in_features=1, out_features=1, bias=True)
(1): ReLU(inplace)
)
(layer3): Linear(in_features=1, out_features=1, bias=True)
)
1 Sequential(
(0): Linear(in_features=1, out_features=1, bias=True)
(1): ReLU(inplace)
)
2 Linear(in_features=1, out_features=1, bias=True)
3 ReLU(inplace)
4 Sequential(
(0): Linear(in_features=1, out_features=1, bias=True)
(1): ReLU(inplace)
)
5 Linear(in_features=1, out_features=1, bias=True)
6 ReLU(inplace)
7 Linear(in_features=1, out_features=1, bias=True)
***************************
[Linear(in_features=1, out_features=1, bias=True), ReLU(inplace), Linear(in_features=1, out_features=1, bias=True), ReLU(inplace), Linear(in_features=1, out_features=1, bias=True)]
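PyTorch also provides named_modules(), which helps match each index in the modules() output above with an attribute of Net. It walks the tree in the same order but yields (name, module) pairs. Nested names are joined with dots, and the root module gets the empty name ''.

The sketch below reproduces only this naming scheme, without PyTorch. It uses a hypothetical dict-based stand-in for this Net and yields names only, not module objects.

```python
# Hypothetical toy tree: attribute name -> sub-tree of registered submodules
# (an empty dict marks a leaf). The shape follows the Net above.
net = {
    "layer": {"0": {}, "1": {}},
    "layer2": {"0": {}, "1": {}},
    "layer3": {},
}


def named_modules(tree, prefix=""):
    # Pre-order DFS like Module.named_modules(): the current module first,
    # then each child's subtree, building dotted names as "parent.child".
    yield prefix
    for name, sub in tree.items():
        yield from named_modules(sub, prefix + ("." if prefix else "") + name)


for i, name in enumerate(named_modules(net)):
    print(i, repr(name))
# 0 ''
# 1 'layer'
# 2 'layer.0'
# 3 'layer.1'
# 4 'layer2'
# 5 'layer2.0'
# 6 'layer2.1'
# 7 'layer3'
```

Index 2 in the modules() output (the first Linear) therefore corresponds to 'layer.0', and index 7 corresponds to 'layer3'.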
Network structure diagram of the code:
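Note: While working on a PyTorch project recently, I came across self.children() and self.modules() and wasn't sure what they did. Thanks to the links above for their clear explanations. This post is just a note for myself so I can review it from time to time.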