常用：conv + bn + relu 組合的權重初始化 (common weight initialization for a conv + BN + ReLU stack)
# Conv layer: Kaiming (He) normal init, std scaled by fan_in.
nn.init.kaiming_normal_(conv.weight, mode='fan_in')
# If the conv is followed by a BatchNorm it is usually built with bias=False,
# in which case conv.bias is None and there is nothing to initialize.
if conv.bias is not None:
    nn.init.constant_(conv.bias, 0.)
# BatchNorm: scale (weight) ~ N(1, 0.02), shift (bias) = 0.
nn.init.normal_(bn.weight, mean=1., std=0.02)
nn.init.constant_(bn.bias, 0.)
# Fully connected layer: Kaiming normal init, std scaled by fan_out.
# Use the in-place `kaiming_normal_`; `kaiming_normal` (no underscore) is deprecated.
nn.init.kaiming_normal_(fc.weight, mode='fan_out')
# A Linear layer built with bias=False has fc.bias = None.
if fc.bias is not None:
    nn.init.constant_(fc.bias, 0.)
resnet:
# ResNet-style init (runs inside a module method, `self` is the network):
# conv weights ~ N(0, sqrt(2 / fan_out)) with fan_out = k_h * k_w * out_channels,
# BatchNorm scale = 1 and shift = 0.
for m in self.modules():
    if isinstance(m, nn.Conv2d):
        # Same distribution as normal_(0, sqrt(2 / (k_h * k_w * out_channels))):
        # fan_out mode uses out_channels * receptive-field size, and the relu gain is sqrt(2).
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, nn.BatchNorm2d):
        # BatchNorm built with affine=False has weight/bias = None; nothing to init.
        if m.affine:
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
inception:
# Inception-style init (runs inside a module method, `self` is the network):
# Conv2d / Linear weights are drawn from a normal distribution truncated at
# +/-2 standard deviations; BatchNorm scale = 1 and shift = 0.
# Import once here rather than re-running the import for every module in the loop.
import scipy.stats as stats

for m in self.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        # A module may carry its own `stddev` attribute; otherwise use 0.1.
        stddev = getattr(m, 'stddev', 0.1)
        # truncnorm's bounds (-2, 2) are in units of `scale`, i.e. +/-2 * stddev.
        truncated = stats.truncnorm(-2, 2, scale=stddev)
        # NOTE: rvs() uses NumPy's global RNG, so torch.manual_seed alone does
        # not make this initialization reproducible.
        samples = torch.as_tensor(truncated.rvs(m.weight.numel()), dtype=m.weight.dtype)
        with torch.no_grad():
            m.weight.copy_(samples.view_as(m.weight))
    elif isinstance(m, nn.BatchNorm2d):
        # BatchNorm built with affine=False has weight/bias = None; nothing to init.
        if m.affine:
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
vgg:
def _initialize_weights(self):
    """Initialize every Conv2d / BatchNorm2d / Linear submodule VGG-style.

    - Conv2d: weight ~ N(0, sqrt(2 / fan_out)) with
      fan_out = k_h * k_w * out_channels; bias = 0.
    - BatchNorm2d: weight = 1, bias = 0 (skipped when affine=False).
    - Linear: weight ~ N(0, 0.01); bias = 0.

    Biases are only initialized when the layer actually has one, so layers
    built with ``bias=False`` are handled.
    """
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            # Equivalent to normal_(0, sqrt(2 / (k_h * k_w * out_channels))).
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # affine=False BatchNorm has weight/bias = None.
            if m.affine:
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            # Guard like the Conv2d branch: Linear(bias=False) has bias = None.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)