09模型創建步驟與nn.Module

一、網絡模型創建步驟

1.1 模型訓練步驟

  • 數據
  • 模型
  • 損失函數
  • 優化器
  • 迭代訓練

1.2 模型創建步驟

在這裏插入圖片描述

1.3 模型構建兩要素:

在這裏插入圖片描述

1.4 模型創建示例——LeNet

LeNet模型結構圖:
在這裏插入圖片描述
LeNet計算圖:
在這裏插入圖片描述
LeNet模型部分代碼:

class LeNet(nn.Module):
    def __init__(self, classes):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, classes)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

說明:在__init__()中實現子模塊的構建,在forward中實現子模塊的拼接,所以在pytorch中,前向傳播過程就是子模塊的拼接過程

二、nn.Module

2.1 torch.nn

torch.nn:Pytroch中的神經網絡模塊,主要包括以下四個子模塊
在這裏插入圖片描述

2.2 nn.Module

  • parameters: 存儲管理nn.Parameter類
  • modules : 存儲管理nn.Module類
  • buffers: 存儲管理緩衝屬性, 如BN層中的running_mean
  • ***_hooks: 存儲管理鉤子函數

8個有序字典:
在這裏插入圖片描述

說明: 自定義模型時,init()方法會繼承父類nn.Module的__init__()函數

class LeNet(nn.Module):
    def __init__(self, classes):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, classes)

在nn.module的__init__()方法中,主要是_construct()

    def __init__(self):
        self._construct()
        # initialize self.training separately from the rest of the internal
        # state, as it is managed differently by nn.Module and ScriptModule
        self.training = True

_construct()創建了8個有序字典

    def _construct(self):
        """
        Initializes internal Module state, shared by both nn.Module and ScriptModule.
        """
        torch._C._log_api_usage_once("python.nn_module")
        self._backend = thnn_backend
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._state_dict_hooks = OrderedDict()
        self._load_state_dict_pre_hooks = OrderedDict()
        self._modules = OrderedDict()

說明: 自定義模型和其他網絡層模塊都直接或間接的繼承nn.Module,所以都會有這8個有序字典,而其子模塊,參數或者其他屬性會存儲在相應的有序字典中

2.3 nn.Module總結

  • 一個module可以包含多個子module
  • 一個module相當於一個運算,必須實現forward()函數
  • 每個module都有8個字典管理它的屬性
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章