Pytorch中 nn.ModuleList 與nn.Sequential 的區別

nn.ModuleList

  • Class torch.nn.ModuleList(modules=None)
    簡單的說,就是把子模塊存儲在list中.它類似於list, 既可以 append 操作,也可以做 insert 操作,也可以 extend 操作. 但是由於把layers存入Modulelist中後只是完成了存儲作用,所以不能直接在forward中直接運行,需要通過索引調出相應的submodule.

    class MyModule(nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
            # extend操作
    		self.linears.extend([nn.Linear(10, 10),nn.Linear(10, 10) ])
    		# append操作
        	self.linears.append(nn.Linear(10, 10))
        def forward(self, x):
            # ModuleList can act as an iterable, or be indexed using ints
            for i, l in enumerate(self.linears):
                x = self.linears[i // 2](x) + l(x)
            return x
    

nn.Sequential

  • Class torch.nn.Sequential(*args)
    順序容器.模塊將按照順序存進sequential中,相當於一個包裝起來的子模塊集,可以在forward中直接運行.

    class MyModule(nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            self.layers = nn.Sequential(
              nn.Conv2d(1,20,5),
              nn.ReLU(),
              nn.Conv2d(20,64,5),
              nn.ReLU()
            )
        def forward(self, x):
            # ModuleList can act as an iterable, or be indexed using ints
            x = self.layers
            return x
    
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章