Python小練習:權重初始化(Weight Initialization)

Python小練習:權重初始化(Weight Initialization)

作者:凱魯嘎吉 - 博客園 http://www.cnblogs.com/kailugaji/

調用Pytorch中的torch.nn.init.xxx實現對模型權重與偏置初始化。

1. weight_init_test.py

 1 # -*- coding: utf-8 -*-
 2 # Author:凱魯嘎吉 Coral Gajic
 3 # https://www.cnblogs.com/kailugaji/
 4 # Python小練習:權重初始化(Weight Initialization)
 5 # Custom weight init for Conv2D and Linear layers.
 6 import torch
 7 import torch.nn.functional as F
 8 import torch.nn as nn
 9 # 根據網絡層的不同定義不同的初始化方式
10 # 以下是兩種不同的初始化方式:
11 # 正態分佈+常數
12 def weight_init(m):
13     if isinstance(m, nn.Linear):
14         # 如果傳入的參數是 nn.Linear 類型,則執行以下操作:
15         nn.init.xavier_normal_(m.weight) # 將權重初始化爲 Xavier 正態分佈
16         nn.init.constant_(m.bias, 0) # 將權重初始化爲常數
17     elif isinstance(m, nn.Conv2d):
18         # 如果傳入的參數是 nn.Conv2d 類型,則執行以下操作:
19         nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') # 將權重初始化爲正態分佈
20     elif isinstance(m, nn.BatchNorm2d):
21         # 如果傳入的參數是 nn.BatchNorm2d 類型,則執行以下操作:
22         nn.init.constant_(m.weight, 1)
23         nn.init.constant_(m.bias, 0)
24 
25 # 正交+常數
26 def weight_init2(m):
27     if isinstance(m, nn.Linear):
28         # 如果傳入的參數是 nn.Linear 類型,則執行以下操作:
29         nn.init.orthogonal_(m.weight.data) # 對權重矩陣進行正交化操作,使其具有對稱性。
30         if hasattr(m.bias, 'data'):
31             m.bias.data.fill_(0.0) # 如果傳入的參數包含偏置項,則將其填充爲零。
32     elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
33         # 如果傳入的參數是 nn.Conv2d 或 nn.ConvTranspose2d 類型,則執行以下操作:
34         gain = nn.init.calculate_gain('relu') # 用於計算激活函數的增益
35         nn.init.orthogonal_(m.weight.data, gain) # 對權重矩陣進行正交化操作,使其具有對稱性。
36         if hasattr(m.bias, 'data'):
37             m.bias.data.fill_(0.0) # 如果傳入的參數包含偏置項,則將其填充爲零。
38 
39 class Net(nn.Module):
40     def __init__(self, input_size=1):
41         self.input_size = input_size
42         super(Net, self).__init__()
43         self.fc1 = nn.Linear(self.input_size, 2)
44         self.fc2 = nn.Linear(2, 4)
45         self.fc3 = nn.Linear(4, 2)
46 
47     def forward(self, x):
48         x = x.view(-1, self.input_size)
49         x = F.relu(self.fc1(x))
50         x = F.relu(self.fc2(x))
51         x = self.fc3(x)
52         return F.log_softmax(x, dim=1)
53 
54 torch.manual_seed(1)
55 num = 4 # 輸入維度
56 x = torch.randn(1, num)
57 # 方式1:
58 model = Net(input_size = num)
59 print('網絡結構:\n', model)
60 print('輸入:\n', x)
61 model.apply(weight_init)
62 y = model(x)
63 print('輸出1:\n', y.data)
64 print('權重1:\n', model.fc1.weight.data)
65 # 方式2:
66 model = Net(input_size = num)
67 model.apply(weight_init2)
68 y = model(x)
69 print('輸出2:\n', y.data)
70 print('權重2:\n', model.fc1.weight.data)

2. 結果

D:\ProgramData\Anaconda3\python.exe "D:/Python code/2023.3 exercise/Neural Network/weight_init_test.py"
網絡結構:
 Net(
  (fc1): Linear(in_features=4, out_features=2, bias=True)
  (fc2): Linear(in_features=2, out_features=4, bias=True)
  (fc3): Linear(in_features=4, out_features=2, bias=True)
)
輸入:
 tensor([[0.6614, 0.2669, 0.0617, 0.6213]])
輸出1:
 tensor([[-0.7233, -0.6639]])
權重1:
 tensor([[ 2.0709, -1.0573,  0.9230, -0.7373],
        [ 0.1879, -0.2766,  0.7962,  1.4599]])
輸出2:
 tensor([[-0.6951, -0.6912]])
權重2:
 tensor([[-0.8471, -0.4721,  0.1653,  0.1795],
        [-0.4072,  0.5991, -0.6437,  0.2467]])

Process finished with exit code 0

完成。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章