Python小練習:權重初始化(Weight Initialization)
作者:凱魯嘎吉 - 博客園 http://www.cnblogs.com/kailugaji/
調用Pytorch中的torch.nn.init.xxx實現對模型權重與偏置初始化。
1. weight_init_test.py
1 # -*- coding: utf-8 -*- 2 # Author:凱魯嘎吉 Coral Gajic 3 # https://www.cnblogs.com/kailugaji/ 4 # Python小練習:權重初始化(Weight Initialization) 5 # Custom weight init for Conv2D and Linear layers. 6 import torch 7 import torch.nn.functional as F 8 import torch.nn as nn 9 # 根據網絡層的不同定義不同的初始化方式 10 # 以下是兩種不同的初始化方式: 11 # 正態分佈+常數 12 def weight_init(m): 13 if isinstance(m, nn.Linear): 14 # 如果傳入的參數是 nn.Linear 類型,則執行以下操作: 15 nn.init.xavier_normal_(m.weight) # 將權重初始化爲 Xavier 正態分佈 16 nn.init.constant_(m.bias, 0) # 將權重初始化爲常數 17 elif isinstance(m, nn.Conv2d): 18 # 如果傳入的參數是 nn.Conv2d 類型,則執行以下操作: 19 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') # 將權重初始化爲正態分佈 20 elif isinstance(m, nn.BatchNorm2d): 21 # 如果傳入的參數是 nn.BatchNorm2d 類型,則執行以下操作: 22 nn.init.constant_(m.weight, 1) 23 nn.init.constant_(m.bias, 0) 24 25 # 正交+常數 26 def weight_init2(m): 27 if isinstance(m, nn.Linear): 28 # 如果傳入的參數是 nn.Linear 類型,則執行以下操作: 29 nn.init.orthogonal_(m.weight.data) # 對權重矩陣進行正交化操作,使其具有對稱性。 30 if hasattr(m.bias, 'data'): 31 m.bias.data.fill_(0.0) # 如果傳入的參數包含偏置項,則將其填充爲零。 32 elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 33 # 如果傳入的參數是 nn.Conv2d 或 nn.ConvTranspose2d 類型,則執行以下操作: 34 gain = nn.init.calculate_gain('relu') # 用於計算激活函數的增益 35 nn.init.orthogonal_(m.weight.data, gain) # 對權重矩陣進行正交化操作,使其具有對稱性。 36 if hasattr(m.bias, 'data'): 37 m.bias.data.fill_(0.0) # 如果傳入的參數包含偏置項,則將其填充爲零。 38 39 class Net(nn.Module): 40 def __init__(self, input_size=1): 41 self.input_size = input_size 42 super(Net, self).__init__() 43 self.fc1 = nn.Linear(self.input_size, 2) 44 self.fc2 = nn.Linear(2, 4) 45 self.fc3 = nn.Linear(4, 2) 46 47 def forward(self, x): 48 x = x.view(-1, self.input_size) 49 x = F.relu(self.fc1(x)) 50 x = F.relu(self.fc2(x)) 51 x = self.fc3(x) 52 return F.log_softmax(x, dim=1) 53 54 torch.manual_seed(1) 55 num = 4 # 輸入維度 56 x = torch.randn(1, num) 57 # 方式1: 58 model = Net(input_size = num) 59 print('網絡結構:\n', model) 60 print('輸入:\n', x) 61 model.apply(weight_init) 62 y = model(x) 63 print('輸出1:\n', y.data) 64 print('權重1:\n', model.fc1.weight.data) 65 # 方式2: 66 model = Net(input_size = num) 67 model.apply(weight_init2) 68 y = model(x) 69 print('輸出2:\n', y.data) 70 print('權重2:\n', model.fc1.weight.data)
2. 結果
D:\ProgramData\Anaconda3\python.exe "D:/Python code/2023.3 exercise/Neural Network/weight_init_test.py" 網絡結構: Net( (fc1): Linear(in_features=4, out_features=2, bias=True) (fc2): Linear(in_features=2, out_features=4, bias=True) (fc3): Linear(in_features=4, out_features=2, bias=True) ) 輸入: tensor([[0.6614, 0.2669, 0.0617, 0.6213]]) 輸出1: tensor([[-0.7233, -0.6639]]) 權重1: tensor([[ 2.0709, -1.0573, 0.9230, -0.7373], [ 0.1879, -0.2766, 0.7962, 1.4599]]) 輸出2: tensor([[-0.6951, -0.6912]]) 權重2: tensor([[-0.8471, -0.4721, 0.1653, 0.1795], [-0.4072, 0.5991, -0.6437, 0.2467]]) Process finished with exit code 0
完成。