Module.children() vs Module.modules()

In one sentence:

Module.modules(): performs a depth-first traversal. Use Module.modules() if you want to recursively iterate over all modules, including the module itself.

Module.children(): use Module.children() if you only want to iterate over the immediate child modules.

A detailed explanation can be found here:

Link: https://discuss.pytorch.org/t/module-children-vs-module-modules/4551/3
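As a quick illustration (a minimal sketch, not taken from the linked thread), the named variants named_children() and named_modules() behave the same way but also yield each module's name, which makes the difference easy to see:

```python
import torch
from torch import nn

net = nn.Sequential(nn.Linear(2, 2), nn.ReLU(True))

# named_children(): only the direct children, keyed by attribute name.
print([name for name, _ in net.named_children()])   # ['0', '1']

# named_modules(): depth-first walk that starts with the module itself
# (the root's name is the empty string).
print([name for name, _ in net.named_modules()])    # ['', '0', '1']
```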

A more intuitive explanation can be found below; part of the code and the network structure diagram come from the following link:

Link: https://blog.csdn.net/dss_dssssd/article/details/83958518

The code:

import torch
from torch import nn

# hyper parameters
in_dim = 1
n_hidden_1 = 1
n_hidden_2 = 1
out_dim = 1

class Net(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super().__init__()

        self.layer = nn.Sequential(
            nn.Linear(in_dim, n_hidden_1),
            nn.ReLU(True)
        )
        self.layer2 = nn.Sequential(
            nn.Linear(n_hidden_1, n_hidden_2),
            nn.ReLU(True),
        )
        self.layer3 = nn.Linear(n_hidden_2, out_dim)
        print("******result of the self.children()******")
        for i, module in enumerate(self.children()):
            print(i, module)

        print("******result of the self.modules()******")
        for i, module in enumerate(self.modules()):
            print(i, module)

    def forward(self, x):
        x = self.layer(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x

all_layer = []

# collect the individual leaf modules and drop the Sequential containers
def remove_sequential(network):
    for layer in network.children():
        if isinstance(layer, nn.Sequential):  # recurse into containers
            remove_sequential(layer)
        elif list(layer.children()) == []:    # a leaf module, keep it
            all_layer.append(layer)

model = Net(in_dim, n_hidden_1, n_hidden_2, out_dim)

remove_sequential(model)
print('***************************')
print(all_layer)

The output of the code:

******result of the self.children()******
0 Sequential(
  (0): Linear(in_features=1, out_features=1, bias=True)
  (1): ReLU(inplace)
)
1 Sequential(
  (0): Linear(in_features=1, out_features=1, bias=True)
  (1): ReLU(inplace)
)
2 Linear(in_features=1, out_features=1, bias=True)
******result of the self.modules()******
0 Net(
  (layer): Sequential(
    (0): Linear(in_features=1, out_features=1, bias=True)
    (1): ReLU(inplace)
  )
  (layer2): Sequential(
    (0): Linear(in_features=1, out_features=1, bias=True)
    (1): ReLU(inplace)
  )
  (layer3): Linear(in_features=1, out_features=1, bias=True)
)
1 Sequential(
  (0): Linear(in_features=1, out_features=1, bias=True)
  (1): ReLU(inplace)
)
2 Linear(in_features=1, out_features=1, bias=True)
3 ReLU(inplace)
4 Sequential(
  (0): Linear(in_features=1, out_features=1, bias=True)
  (1): ReLU(inplace)
)
5 Linear(in_features=1, out_features=1, bias=True)
6 ReLU(inplace)
7 Linear(in_features=1, out_features=1, bias=True)
***************************
[Linear(in_features=1, out_features=1, bias=True), ReLU(inplace), Linear(in_features=1, out_features=1, bias=True), ReLU(inplace), Linear(in_features=1, out_features=1, bias=True)]
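Since modules() already performs the depth-first walk, the same flattening that remove_sequential() does above can also be written in one line (a sketch, not from the linked posts): a leaf is simply a module with no children of its own.

```python
import torch
from torch import nn

model = nn.Sequential(
    nn.Sequential(nn.Linear(1, 1), nn.ReLU(True)),
    nn.Linear(1, 1),
)

# modules() visits every node depth-first; keep only the leaves.
leaves = [m for m in model.modules() if len(list(m.children())) == 0]
print(leaves)  # [Linear, ReLU, Linear]
```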

Network structure diagram of the code (figure not reproduced here):

Note: While recently working on a PyTorch project, I ran into self.children() and self.modules() without understanding them. Many thanks to the links above for the clear explanation. This post is just a record so that I can review it from time to time.
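A common practical use of the distinction (my own example, not from the linked posts): modules() reaches every leaf, so it is the usual loop for weight initialization, while children() stops at the top level.

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(True), nn.Linear(8, 2))

# modules() visits every submodule, so it can initialize all Linear layers.
for m in model.modules():
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        nn.init.zeros_(m.bias)

# children() yields only the 3 direct children; modules() also counts
# the Sequential container itself, giving 4 in total.
print(len(list(model.children())), len(list(model.modules())))  # 3 4
```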
