import torch.nn as nn import torch.nn.functional as F def initialize_weights(self): for m in self.modules(): if isinstance(m, nn.Linear): nn.init.normal_(m.weight.data, 0, 0.01) m.bias.data.zero_() class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.fc1 = nn.Linear(2, 10) self.fc2 = nn.Linear(10, 10) self.fc3 = nn.Linear(10, 1) def forward(self, x): x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x model = Net() initialize_weights(model) for layer in model.modules(): if isinstance(layer, nn.Linear): print('weight = {}'.format(layer.weight)) print('bias = {}'.format(layer.bias))
pytorch模型的數據初始化代碼
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.