I. Steps for Building a Network Model
1.1 Model Training Steps
- Data
- Model
- Loss function
- Optimizer
- Iterative training
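The five steps can be sketched as a minimal PyTorch training loop. This is only an illustrative sketch. The toy regression data (y = 2x + 1 plus noise), the single `nn.Linear` model, the learning rate and the epoch count are assumptions chosen for the example. None of them come from the original text.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# 1. Data: hypothetical toy regression set, y = 2x + 1 plus a little noise
x = torch.randn(64, 1)
y = 2 * x + 1 + 0.1 * torch.randn(64, 1)

# 2. Model
model = nn.Linear(1, 1)

# 3. Loss function
loss_fn = nn.MSELoss()

# 4. Optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# 5. Iterative training
for epoch in range(200):
    optimizer.zero_grad()          # clear gradients from the previous step
    loss = loss_fn(model(x), y)    # forward pass and loss
    loss.backward()                # backpropagate
    optimizer.step()               # update the parameters

print(f"final loss: {loss.item():.4f}")
print("weight:", model.weight.item(), "bias:", model.bias.item())
```

In a real project, the data step would normally be a `Dataset`/`DataLoader` pair, and the loop would iterate over mini-batches.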
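1.2 Model Creation Steps
1.3 Two Key Elements of Model Construction:
1.4 Model Creation Example: LeNet
LeNet model architecture diagram:
LeNet computational graph:
LeNet model code (excerpt):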
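```python
import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    def __init__(self, classes):
        super(LeNet, self).__init__()
        # Build the sub-modules (layer sizes assume 3x32x32 input images)
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, classes)

    def forward(self, x):
        # Assemble the sub-modules into the forward pass
        out = F.relu(self.conv1(x))      # (N, 6, 28, 28)
        out = F.max_pool2d(out, 2)       # (N, 6, 14, 14)
        out = F.relu(self.conv2(out))    # (N, 16, 10, 10)
        out = F.max_pool2d(out, 2)       # (N, 16, 5, 5)
        out = out.view(out.size(0), -1)  # flatten to (N, 400)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)              # (N, classes)
        return out
```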
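Note: the sub-modules are built in `__init__()` and assembled in `forward()`. In PyTorch, therefore, the forward pass is simply the process of chaining the sub-modules together.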
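A quick way to check that the sub-modules are wired together correctly is to pass a dummy batch through the model. The sketch below makes two assumptions. The inputs are 3-channel 32x32 images, which is the size implied by `16*5*5` in `fc1`. The task has 10 classes, which is hypothetical. The class is repeated in condensed form so the snippet runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    def __init__(self, classes):
        super().__init__()
        # __init__: build the sub-modules
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, classes)

    def forward(self, x):
        # forward: chain the sub-modules together
        out = F.max_pool2d(F.relu(self.conv1(x)), 2)    # -> (N, 6, 14, 14)
        out = F.max_pool2d(F.relu(self.conv2(out)), 2)  # -> (N, 16, 5, 5)
        out = out.view(out.size(0), -1)                 # -> (N, 400)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        return self.fc3(out)


model = LeNet(classes=10)
x = torch.randn(4, 3, 32, 32)   # hypothetical batch of four 32x32 RGB images
logits = model(x)               # model(x) runs forward() via nn.Module.__call__
print(logits.shape)             # torch.Size([4, 10])
```

Calling `model(x)` instead of `model.forward(x)` is the idiomatic form, because `__call__` also runs any registered hooks.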
II. nn.Module
2.1 torch.nn
torch.nn: PyTorch's neural network module, which mainly consists of the following four submodules:
2.2 nn.Module
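- `parameters`: stores and manages nn.Parameter objects
- `modules`: stores and manages nn.Module objects (sub-modules)
- `buffers`: stores and manages buffer attributes, such as running_mean in a BN layer
- `*_hooks`: stores and manages hook functions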
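These containers can be inspected on a built-in layer. The sketch below uses `nn.BatchNorm2d` because it has all three kinds of state. The methods `named_parameters()`, `named_buffers()` and `register_forward_hook()` are public nn.Module methods. `_forward_hooks` is an internal attribute and is read here only for illustration. The hook function `print_shape` is hypothetical.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)

# Learnable parameters are kept in bn._parameters
param_names = [name for name, _ in bn.named_parameters()]
print(param_names)    # ['weight', 'bias']

# Non-learnable state such as running_mean is kept in bn._buffers
buffer_names = [name for name, _ in bn.named_buffers()]
print(buffer_names)   # ['running_mean', 'running_var', 'num_batches_tracked']


# Hypothetical hook: report the output shape after each forward pass
def print_shape(module, inputs, output):
    print(type(module).__name__, "output shape:", tuple(output.shape))


handle = bn.register_forward_hook(print_shape)
hooks_registered = len(bn._forward_hooks)    # the hook is now stored
bn(torch.randn(2, 3, 4, 4))                  # running forward triggers the hook
handle.remove()                              # deletes the entry again
hooks_after_remove = len(bn._forward_hooks)
```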
Eight ordered dictionaries (OrderedDict):
Note: when you define a custom model, its `__init__()` method calls the `__init__()` of the parent class nn.Module through `super()`:
```python
import torch.nn as nn


class LeNet(nn.Module):
    def __init__(self, classes):
        # Runs nn.Module.__init__(), which creates the ordered dictionaries
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, classes)
```
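The `super().__init__()` call matters because nn.Module's `__setattr__` handles sub-module assignment, and `__setattr__` needs the dictionaries created by the parent constructor. The sketch below uses a hypothetical `BrokenNet` that omits the call. The exact error text may differ between PyTorch versions.

```python
import torch.nn as nn


class BrokenNet(nn.Module):          # hypothetical example class
    def __init__(self):
        # super().__init__() deliberately omitted, so _modules and the
        # other dictionaries have not been created yet
        self.conv1 = nn.Conv2d(3, 6, 5)


error_message = None
try:
    BrokenNet()
except AttributeError as err:
    error_message = str(err)
    print(error_message)   # cannot assign module before Module.__init__() call
```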
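In nn.Module's `__init__()`, `_construct()` does most of the work. This excerpt comes from an older PyTorch release. Newer releases no longer have `_construct()` and set up the dictionaries directly in `__init__()`.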
```python
def __init__(self):
    self._construct()
    # initialize self.training separately from the rest of the internal
    # state, as it is managed differently by nn.Module and ScriptModule
    self.training = True
```
`_construct()` creates the 8 ordered dictionaries:
```python
def _construct(self):
    """
    Initializes internal Module state, shared by both nn.Module and ScriptModule.
    """
    torch._C._log_api_usage_once("python.nn_module")
    self._backend = thnn_backend
    self._parameters = OrderedDict()
    self._buffers = OrderedDict()
    self._backward_hooks = OrderedDict()
    self._forward_hooks = OrderedDict()
    self._forward_pre_hooks = OrderedDict()
    self._state_dict_hooks = OrderedDict()
    self._load_state_dict_pre_hooks = OrderedDict()
    self._modules = OrderedDict()
```
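On a recent PyTorch release, you can confirm that these eight dictionaries still exist on every module. This sketch assumes a current PyTorch version. Keep a few caveats in mind:

- `_construct()` itself and `_backend` may not exist on such a version.
- Newer releases add further hook dictionaries.
- These underscore attributes are internal details, not public API.

```python
import torch.nn as nn

# The eight dictionary names created by _construct() in the quoted source
names = [
    "_parameters", "_buffers", "_backward_hooks", "_forward_hooks",
    "_forward_pre_hooks", "_state_dict_hooks",
    "_load_state_dict_pre_hooks", "_modules",
]

m = nn.Module()
for name in names:
    d = getattr(m, name)
    # print each dictionary's name, concrete type and (empty) size
    print(f"{name:28s} {type(d).__name__:12s} {len(d)}")
```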
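Note: custom models and built-in network layers all inherit from nn.Module, directly or indirectly, so every one of them has these 8 ordered dictionaries. Their sub-modules, parameters and other attributes are stored in the corresponding dictionary. This works because nn.Module overrides `__setattr__`: an assignment such as `self.conv1 = nn.Conv2d(3, 6, 5)` is routed into `_modules`.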
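nn.Module's `__setattr__` does this routing: the type of the assigned value decides which dictionary the attribute lands in. The sketch below uses a hypothetical `Demo` model. It reads `_modules`, `_parameters` and `_buffers` directly, which is for illustration only.

```python
import torch
import torch.nn as nn


class Demo(nn.Module):                  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)                       # Module       -> _modules
        self.scale = nn.Parameter(torch.ones(1))        # Parameter    -> _parameters
        self.register_buffer("count", torch.zeros(1))   # buffer       -> _buffers
        self.note = torch.zeros(1)                      # plain tensor -> ordinary __dict__


m = Demo()
print(list(m._modules))       # ['fc']
print(list(m._parameters))    # ['scale']
print(list(m._buffers))       # ['count']
print("note" in m.__dict__)   # True

# The child nn.Linear keeps its own weight and bias in its own _parameters;
# named_parameters() walks the module hierarchy recursively.
print(list(m.fc._parameters))                  # ['weight', 'bias']
print([n for n, _ in m.named_parameters()])    # ['scale', 'fc.weight', 'fc.bias']
```

Only registered parameters are returned by `m.parameters()` and seen by an optimizer. A plain tensor attribute such as `note` is not.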
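2.3 nn.Module Summary
- A module can contain multiple sub-modules
- A module represents a computation and must implement the `forward()` function
- Every module manages its attributes through 8 dictionaries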
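The first two summary points can be illustrated together. The sketch uses a hypothetical `Net` built from two hypothetical `Block` sub-modules, plus a hypothetical `NoForward` module. `NoForward` omits `forward()` and therefore fails when called.

```python
import torch
import torch.nn as nn


class Block(nn.Module):                 # hypothetical sub-module
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.linear(x))


class Net(nn.Module):                   # one module containing several sub-modules
    def __init__(self):
        super().__init__()
        self.block1 = Block()
        self.block2 = Block()
        self.head = nn.Linear(8, 2)

    def forward(self, x):
        return self.head(self.block2(self.block1(x)))


class NoForward(nn.Module):             # hypothetical module that forgets forward()
    pass


net = Net()
children = [name for name, _ in net.named_children()]
print(children)                  # ['block1', 'block2', 'head']
out = net(torch.randn(3, 8))
print(out.shape)                 # torch.Size([3, 2])

raised = False
try:
    NoForward()(torch.randn(1))
except NotImplementedError as err:
    raised = True
    print("NotImplementedError:", err)
```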