# Module initialization: IDE (ID-discriminative Embedding) network definition.
class IDE(nn.Module):
    """IDE (ID-discriminative Embedding) person re-ID model.

    A ResNet-50 backbone (through ``layer4``) followed by global average
    pooling produces a 2048-d feature per image; a small FC head maps that
    feature to identity logits.

    :param num_classes: number of identity classes predicted by the head.
    """

    def __init__(self, num_classes):
        super(IDE, self).__init__()
        # NOTE(review): pretrained=True downloads ImageNet weights on first
        # use; assumes `resnet50` is imported from torchvision elsewhere in
        # this file — confirm.
        resnet = resnet50(pretrained=True)
        self.backbone = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4,
            # Collapse spatial dims to 1x1 -> (N, 2048, 1, 1).
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.1),
            nn.Dropout(p=0.5),
            nn.Linear(512, num_classes)
        )
        # Bottleneck FC: He init (fan_out) for the LeakyReLU that follows.
        nn.init.kaiming_normal_(self.classifier[0].weight, mode='fan_out')
        nn.init.constant_(self.classifier[0].bias, 0.)
        # BatchNorm: scale ~ N(1, 0.02), shift cleared to zero.
        nn.init.normal_(self.classifier[1].weight, mean=1., std=0.02)
        nn.init.constant_(self.classifier[1].bias, 0.)
        # Classification FC: small-variance Gaussian so initial logits are near zero.
        nn.init.normal_(self.classifier[4].weight, std=0.001)
        nn.init.constant_(self.classifier[4].bias, 0.)

    def forward(self, x):
        """
        :param x: input image batch of shape (N, C, H, W)
        :return: tuple ``(feature, logits)`` — feature of shape (N, 2048),
                 logits of shape (N, num_classes)
        """
        x = self.backbone(x)
        # BUG FIX: the original `x.squeeze()` also removed the batch
        # dimension when N == 1, producing a 1-D (2048,) tensor that breaks
        # nn.Linear / BatchNorm1d. Flatten only the trailing spatial dims so
        # the shape is always (N, 2048).
        x = x.view(x.size(0), -1)
        y = self.classifier(x)
        return x, y
# ---------------------------------------------------------------------------
# Per-layer initialization of the FC classification heads
# ---------------------------------------------------------------------------
# fc softmax loss
# Identity-classification heads: three heads over 2048-d global features and
# five heads over 256-d reduced features, each mapping to `num_classes` logits.
# NOTE(review): this run of statements belongs to an __init__ that is not
# visible in this chunk (the naming matches an MGN-style multi-branch layout:
# one global + two part branches) — confirm against the enclosing class.
self.fc_id_2048_0 = nn.Linear(2048, num_classes)
self.fc_id_2048_1 = nn.Linear(2048, num_classes)
self.fc_id_2048_2 = nn.Linear(2048, num_classes)
self.fc_id_256_1_0 = nn.Linear(256, num_classes)
self.fc_id_256_1_1 = nn.Linear(256, num_classes)
self.fc_id_256_2_0 = nn.Linear(256, num_classes)
self.fc_id_256_2_1 = nn.Linear(256, num_classes)
self.fc_id_256_2_2 = nn.Linear(256, num_classes)
# Apply the shared small-variance initialization to every head.
self._init_fc(self.fc_id_2048_0)
self._init_fc(self.fc_id_2048_1)
self._init_fc(self.fc_id_2048_2)
self._init_fc(self.fc_id_256_1_0)
self._init_fc(self.fc_id_256_1_1)
self._init_fc(self.fc_id_256_2_0)
self._init_fc(self.fc_id_256_2_1)
self._init_fc(self.fc_id_256_2_2)
@staticmethod
def _init_reduction(reduction):
# conv
nn.init.kaiming_normal_(reduction[0].weight, mode='fan_in')
# nn.init.constant_(reduction[0].bias, 0.)
# bn
nn.init.normal_(reduction[1].weight, mean=1., std=0.02)
nn.init.constant_(reduction[1].bias, 0.)
@staticmethod
def _init_fc(fc):
# nn.init.kaiming_normal_(fc.weight, mode='fan_out')
nn.init.normal_(fc.weight, std=0.001)
nn.init.constant_(fc.bias, 0.)