深度學習中網絡層的初始化：Linear 層與 BatchNorm1d（BN）層

模塊初始化

class IDE(nn.Module):
    """ResNet-50 backbone followed by a bottleneck classification head.

    The backbone reuses every stage of ``resnet50`` (presumably torchvision's)
    except its original ``fc`` head, then global-average-pools to a 2048-d
    feature per image. The head is
    Linear(2048, 512) -> BatchNorm1d(512) -> LeakyReLU(0.1) -> Dropout(0.5)
    -> Linear(512, num_classes).

    :param num_classes: number of outputs of the final Linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()

        # NOTE(review): newer torchvision releases deprecate ``pretrained=True``
        # in favour of ``weights=...``; kept as-is for compatibility.
        resnet = resnet50(pretrained=True)
        # All ResNet stages up to layer4; the adaptive pool collapses the
        # spatial dims to 1x1, so the output is (N, 2048, 1, 1).
        self.backbone = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4,
            nn.AdaptiveAvgPool2d((1, 1)),
        )

        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.1),
            nn.Dropout(p=0.5),
            nn.Linear(512, num_classes),
        )

        # classifier[0], Linear 2048 -> 512: Kaiming-normal weights
        # (fan_out mode) and zero bias.
        nn.init.kaiming_normal_(self.classifier[0].weight, mode='fan_out')
        nn.init.constant_(self.classifier[0].bias, 0.)

        # classifier[1], BatchNorm1d: scale (gamma) ~ N(1, 0.02^2) and
        # shift (beta) = 0, so its affine part starts close to identity.
        nn.init.normal_(self.classifier[1].weight, mean=1., std=0.02)
        nn.init.constant_(self.classifier[1].bias, 0.)

        # classifier[4], final Linear 512 -> num_classes: very small weights
        # (std 0.001) and zero bias, which keeps the initial logits small.
        nn.init.normal_(self.classifier[4].weight, std=0.001)
        nn.init.constant_(self.classifier[4].bias, 0.)

    def forward(self, x):
        """Compute the pooled feature and the class logits.

        :param x: input image batch of shape (N, C, H, W)
        :return: tuple ``(feature, logits)`` of shapes (N, 2048) and
            (N, num_classes)
        """
        x = self.backbone(x)
        # flatten(1) keeps the batch dimension. squeeze() would also drop it
        # when N == 1, handing a 1-D tensor to BatchNorm1d, which rejects
        # 1-D input.
        x = x.flatten(1)

        y = self.classifier(x)
        return x, y


單層初始化

# ########## Initialize the FC layers (per-layer initialization) ##########
        # NOTE(review): fragment of a model's __init__; the enclosing method and
        # the origin of num_classes are not shown in this excerpt.
        # ID-classification heads (original note: "fc softmax loss"): each maps
        # a 2048-d or 256-d feature to num_classes logits -- presumably trained
        # with a softmax cross-entropy loss; TODO confirm.
        self.fc_id_2048_0 = nn.Linear(2048, num_classes)
        self.fc_id_2048_1 = nn.Linear(2048, num_classes)
        self.fc_id_2048_2 = nn.Linear(2048, num_classes)
        self.fc_id_256_1_0 = nn.Linear(256, num_classes)
        self.fc_id_256_1_1 = nn.Linear(256, num_classes)
        self.fc_id_256_2_0 = nn.Linear(256, num_classes)
        self.fc_id_256_2_1 = nn.Linear(256, num_classes)
        self.fc_id_256_2_2 = nn.Linear(256, num_classes)
        # Overwrite PyTorch's default Linear initialization of every head with
        # the class's _init_fc scheme (small normal weights, zero bias).
        self._init_fc(self.fc_id_2048_0)
        self._init_fc(self.fc_id_2048_1)
        self._init_fc(self.fc_id_2048_2)
        self._init_fc(self.fc_id_256_1_0)
        self._init_fc(self.fc_id_256_1_1)
        self._init_fc(self.fc_id_256_2_0)
        self._init_fc(self.fc_id_256_2_1)
        self._init_fc(self.fc_id_256_2_2)



	@staticmethod
    def _init_reduction(reduction):
        # conv
        nn.init.kaiming_normal_(reduction[0].weight, mode='fan_in')
        # nn.init.constant_(reduction[0].bias, 0.)

        # bn
        nn.init.normal_(reduction[1].weight, mean=1., std=0.02)
        nn.init.constant_(reduction[1].bias, 0.)
    
	@staticmethod
    def _init_fc(fc):
        # nn.init.kaiming_normal_(fc.weight, mode='fan_out')
        nn.init.normal_(fc.weight, std=0.001)
        nn.init.constant_(fc.bias, 0.)






發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章