Pytorch 學習(七):Pytorch 網絡模型創建

Pytorch 網絡模型創建

本方法總結自《動手學深度學習》(Pytorch版)github項目

常用的網絡搭建方法有

  • 繼承 Module 方法
  • 利用 Sequential, ModuleList 和 ModuleDict 類創建
  • 多種方法的同時使用

繼承 Module 方法

在 Pytorch 學習(五)中構建多層感知器網絡時,便使用了繼承 torch.nn.Module 的方法,這是最常用的網絡模型創建方法

import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, n_i, n_h, n_o):
        super(MLP, self).__init__()
        self.linear1 = nn.Linear(n_i, n_h)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(n_h, n_o)

    def forward(self, input):
        return self.linear2(self.relu(self.linear1(input)))

利用 Sequential 類

同樣對於 MLP 網絡,可以使用 Sequential 類實現

from collections import OrderedDict

net = nn.Sequential(
  OrderedDict([
    ('linear1', nn.Linear(n_inputs, n_hiddens)),
    ('relu', nn.ReLU()),
    ('linear2', nn.Linear(n_hiddens, n_outputs))
    ])
  )

利用 Sequential 網絡的各層是有序的,同時不需要實現 forward 函數,默認按照對應順序進行前向傳播。構造一個 MySequential 類來進一步理解

class MySequential(nn.Module):
    def __init__(self, *args):
        super(MySequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):  # 如果傳入 OrderedDict 參數
            for key, module in args[0].items():
                self.add_module(key, module)
        else:  # 傳入的爲 module
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def forward(self, input):
        for module in self._modules.values():
            input = module(input)
        return input

同樣利用 MySequential 創建 MLP 網絡

net = MySequential(
  nn.Linear(n_inputs, n_hiddens),
  nn.ReLU(),
  nn.Linear(n_hiddens, n_outputs)
  )
print(net)

net = MySequential(
  OrderedDict([
    ('linear1', nn.Linear(n_inputs, n_hiddens)),
    ('relu', nn.ReLU()),
    ('linear2', nn.Linear(n_hiddens, n_outputs))
    ])
  )
print(net)

利用 ModuleList 和 ModuleDict

ModuleList 和 ModuleDict 的區別是輸入分別爲 list 和 dict,訪問某一層或添加更多層的方式略有區別

  • ModuleList 操作
net = nn.ModuleList([
  nn.Linear(n_inputs, n_hiddens),
  nn.ReLU(),
  ])
net.append(nn.Linear(n_hiddens, n_outputs))
print(net[-1])
  • ModuleDict 操作
net = nn.ModuleDict({
  'linear1', nn.Linear(n_inputs, n_hiddens),
  'relu', nn.ReLU()
  })
net['linear2'] = nn.Linear(n_hiddens, n_outputs)
print(net['linear2'])

這兩種方式都需要手動實現 forward 函數,不能直接調用。同時與傳統的 list 存在區別,後者的參數不計入 net.parameters()

self.linear1 = nn.ModuleList([nn.Linear(n_inputs, n_outputs)])
print(net.linear1.parameters()[0].shape)  # n_inputs, n_outputs

self.linear1 = [nn.Linear(n_inputs, n_outputs)]
print(net.parameters())  # None

總結

  • 四種方法實現模型構建:繼承 Module, 使用 Sequential, 使用 ModuleList 和使用 ModuleDict
  • 繼承 Module 靈活性最高
  • Sequential 有序、自動計算 forward 過程
  • ModuleList 和 ModuleDict 僅作爲容器,需要構建 forward 過程
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章