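In one sentence:
Module.modules(): traverses the module tree depth-first. To recursively iterate over every module in the network, including the module itself, use Module.modules().
Module.children(): to iterate over only the direct child modules (one level down), use Module.children().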
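The difference can be checked with a quick count. The following is a minimal sketch on a hypothetical toy model (not the network used later in this post): one nested Sequential plus one direct Linear child.

```python
from torch import nn

# Hypothetical toy model, used only for illustration:
# a nested Sequential (Linear + ReLU) followed by a direct Linear child.
model = nn.Sequential(
    nn.Sequential(nn.Linear(2, 2), nn.ReLU()),
    nn.Linear(2, 1),
)

children = list(model.children())  # direct children only
modules = list(model.modules())    # the model itself plus all descendants

print(len(children))  # 2: the inner Sequential and the last Linear
print(len(modules))   # 5: model, inner Sequential, Linear, ReLU, Linear
print(modules[0] is model)  # True: modules() yields the module itself first
```

Since modules() visits a container before its contents, the inner Sequential appears before the Linear and ReLU that it holds.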
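A detailed explanation can be found here:
Link: https://discuss.pytorch.org/t/module-children-vs-module-modules/4551/3
A more intuitive explanation can be found at the link below; part of the code and the network structure diagram come from it:
Link: https://blog.csdn.net/dss_dssssd/article/details/83958518
The code is as follows: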
import torch
from torch import nn

# hyperparameters
in_dim = 1
n_hidden_1 = 1
n_hidden_2 = 1
out_dim = 1


class Net(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Linear(in_dim, n_hidden_1),
            nn.ReLU(True)
        )
        self.layer2 = nn.Sequential(
            nn.Linear(n_hidden_1, n_hidden_2),
            nn.ReLU(True),
        )
        self.layer3 = nn.Linear(n_hidden_2, out_dim)

        # children(): direct child modules only
        print("******result of the self.children()******")
        for i, module in enumerate(self.children()):
            print(i, module)

        # modules(): this module and all descendants, depth-first
        print("******result of the self.modules()******")
        for i, module in enumerate(self.modules()):
            print(i, module)

    def forward(self, x):
        # the first block is registered as "layer" in __init__, not "layer1"
        x = self.layer(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x


all_layer = []


# collect the individual leaf components and skip the Sequential containers
def remove_sequential(network):
    for layer in network.children():
        if isinstance(layer, nn.Sequential):
            # recurse into the container to reach its contents
            remove_sequential(layer)
        if not list(layer.children()):
            # a module with no children is a leaf
            all_layer.append(layer)


model = Net(in_dim, n_hidden_1, n_hidden_2, out_dim)
remove_sequential(model)
print('***************************')
print(all_layer)
Running the code produces the following output:
******result of the self.children()******
0 Sequential(
  (0): Linear(in_features=1, out_features=1, bias=True)
  (1): ReLU(inplace)
)
1 Sequential(
  (0): Linear(in_features=1, out_features=1, bias=True)
  (1): ReLU(inplace)
)
2 Linear(in_features=1, out_features=1, bias=True)
******result of the self.modules()******
0 Net(
  (layer): Sequential(
    (0): Linear(in_features=1, out_features=1, bias=True)
    (1): ReLU(inplace)
  )
  (layer2): Sequential(
    (0): Linear(in_features=1, out_features=1, bias=True)
    (1): ReLU(inplace)
  )
  (layer3): Linear(in_features=1, out_features=1, bias=True)
)
1 Sequential(
  (0): Linear(in_features=1, out_features=1, bias=True)
  (1): ReLU(inplace)
)
2 Linear(in_features=1, out_features=1, bias=True)
3 ReLU(inplace)
4 Sequential(
  (0): Linear(in_features=1, out_features=1, bias=True)
  (1): ReLU(inplace)
)
5 Linear(in_features=1, out_features=1, bias=True)
6 ReLU(inplace)
7 Linear(in_features=1, out_features=1, bias=True)
***************************
[Linear(in_features=1, out_features=1, bias=True), ReLU(inplace), Linear(in_features=1, out_features=1, bias=True), ReLU(inplace), Linear(in_features=1, out_features=1, bias=True)]
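The last line of the output shows that the recursive children() walk ends up with the leaf modules only. As a sketch of an alternative (not taken from the linked posts), the same list can probably be obtained by filtering modules() for entries that have no children. The helper name leaf_modules and the toy model below are hypothetical, chosen to mirror the structure of Net.

```python
from torch import nn


def leaf_modules(model: nn.Module) -> list:
    """Return every module without children, in depth-first order."""
    return [m for m in model.modules() if not list(m.children())]


# Hypothetical stand-in with the same shape as Net:
# two Sequential blocks (Linear + ReLU) followed by a single Linear.
toy = nn.Sequential(
    nn.Sequential(nn.Linear(1, 1), nn.ReLU(True)),
    nn.Sequential(nn.Linear(1, 1), nn.ReLU(True)),
    nn.Linear(1, 1),
)

leaves = leaf_modules(toy)
print([type(m).__name__ for m in leaves])
# ['Linear', 'ReLU', 'Linear', 'ReLU', 'Linear']
```

One difference from the recursive version: this filter drops every container type (for example nn.ModuleList as well), whereas the recursion above only descends into nn.Sequential.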
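Network structure diagram for the code: (figure omitted)
Note: While working on a project in PyTorch recently, I came across self.children() and self.modules() and did not understand the difference. Thanks to the links above for their clear explanations. This post is only a record so that I can review it later.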