Pytorch 学习(七):Pytorch 网络模型创建

Pytorch 网络模型创建

本方法总结自《动手学深度学习》(Pytorch版)github项目

常用的网络搭建方法有

  • 继承 Module 方法
  • 利用 Sequential, ModuleList 和 ModuleDict 类创建
  • 多种方法的同时使用

继承 Module 方法

在 Pytorch 学习(五)中构建多层感知器网络时,便使用了继承 torch.nn.Module 的方法,这是最常用的网络模型创建方法

import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, n_i, n_h, n_o):
        super(MLP, self).__init__()
        self.linear1 = nn.Linear(n_i, n_h)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(n_h, n_o)

    def forward(self, input):
        return self.linear2(self.relu(self.linear1(input)))

利用 Sequential 类

同样对于 MLP 网络,可以使用 Sequential 类实现

from collections import OrderedDict

net = nn.Sequential(
  OrderedDict([
    ('linear1', nn.Linear(n_inputs, n_hiddens)),
    ('relu', nn.ReLU()),
    ('linear2', nn.Linear(n_hiddens, n_outputs))
    ])
  )

利用 Sequential 网络的各层是有序的,同时不需要实现 forward 函数,默认按照对应顺序进行前向传播。构造一个 MySequential 类来进一步理解

class MySequential(nn.Module):
    def __init__(self, *args):
        super(MySequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):  # 如果传入 OrderedDict 参数
            for key, module in args[0].items():
                self.add_module(key, module)
        else:  # 传入的为 module
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def forward(self, input):
        for module in self._modules.values():
            input = module(input)
        return input

同样利用 MySequential 创建 MLP 网络

net = MySequential(
  nn.Linear(n_inputs, n_hiddens),
  nn.ReLU(),
  nn.Linear(n_hiddens, n_outputs)
  )
print(net)

net = MySequential(
  OrderedDict([
    ('linear1', nn.Linear(n_inputs, n_hiddens)),
    ('relu', nn.ReLU()),
    ('linear2', nn.Linear(n_hiddens, n_outputs))
    ])
  )
print(net)

利用 ModuleList 和 ModuleDict

ModuleList 和 ModuleDict 的区别是输入分别为 list 和 dict,访问某一层或添加更多层的方式略有区别

  • ModuleList 操作
net = nn.ModuleList([
  nn.Linear(n_inputs, n_hiddens),
  nn.ReLU(),
  ])
net.append(nn.Linear(n_hiddens, n_outputs))
print(net[-1])
  • ModuleDict 操作
net = nn.ModuleDict({
  'linear1', nn.Linear(n_inputs, n_hiddens),
  'relu', nn.ReLU()
  })
net['linear2'] = nn.Linear(n_hiddens, n_outputs)
print(net['linear2'])

这两种方式都需要手动实现 forward 函数,不能直接调用。同时与传统的 list 存在区别,后者的参数不计入 net.parameters()

self.linear1 = nn.ModuleList([nn.Linear(n_inputs, n_outputs)])
print(net.linear1.parameters()[0].shape)  # n_inputs, n_outputs

self.linear1 = [nn.Linear(n_inputs, n_outputs)]
print(net.parameters())  # None

总结

  • 四种方法实现模型构建:继承 Module, 使用 Sequential, 使用 ModuleList 和使用 ModuleDict
  • 继承 Module 灵活性最高
  • Sequential 有序、自动计算 forward 过程
  • ModuleList 和 ModuleDict 仅作为容器,需要构建 forward 过程
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章