PyTorch實現FPN及FCOS，附有詳細註釋！

FPN

在這裏插入圖片描述

class FPN(nn.Module):
    """Feature Pyramid Network built on a ResNet-style bottom-up pathway.

    The bottom-up pathway (conv1/bn1/relu/maxpool followed by four residual
    stages) produces C2..C5. The top-down pathway merges them into four
    256-channel maps P2..P5 at strides 4, 8, 16 and 32 w.r.t. the input.

    This listing expects ``math``, ``torch.nn as nn`` and
    ``torch.nn.functional as F`` to be imported at module level.

    Args:
        block: residual block class. It must expose a class attribute
            ``expansion`` and be constructible as
            ``block(inplanes, planes, stride, downsample)`` and
            ``block(inplanes, planes)``.
        layers: number of blocks in each of the four residual stages.
    """

    def __init__(self, block, layers):
        super(FPN, self).__init__()
        self.inplanes = 64

        # Stem: 7x7 stride-2 conv + BN + ReLU + 3x3 stride-2 max-pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Bottom-up layers (C2..C5)
        self.layer1 = self._make_layer(block,  64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # The channel counts of C2..C5 depend on the block's expansion factor
        # (e.g. 4 for a Bottleneck block, 1 for a basic block). Derive them
        # instead of hard-coding 256/512/1024/2048, which only fits expansion 4.
        c2_channels, c3_channels, c4_channels, c5_channels = (
            planes * block.expansion for planes in (64, 128, 256, 512)
        )

        # Top layer: 1x1 conv reducing C5 to 256 channels (this becomes P5).
        self.toplayer = nn.Conv2d(c5_channels, 256, kernel_size=1, stride=1, padding=0)

        # Smooth layers: 3x3 convs applied to the merged P4, P3 and P2.
        self.smooth1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.smooth2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.smooth3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)

        # Lateral layers: 1x1 convs bringing C4, C3 and C2 to 256 channels.
        self.latlayer1 = nn.Conv2d(c4_channels, 256, kernel_size=1, stride=1, padding=0)
        self.latlayer2 = nn.Conv2d(c3_channels, 256, kernel_size=1, stride=1, padding=0)
        self.latlayer3 = nn.Conv2d(c2_channels, 256, kernel_size=1, stride=1, padding=0)

        # Conv weights ~ N(0, sqrt(2 / (k*k*out_channels))), i.e. He init with
        # fan-out; BatchNorm starts as the identity (weight 1, bias 0).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one residual stage of ``blocks`` blocks.

        Only the first block may change stride or channel count. When it does,
        the shortcut gets a 1x1 conv + BN projection. Updates
        ``self.inplanes`` to ``planes * block.expansion`` for the next stage.
        """
        downsample  = None
        if stride != 1 or self.inplanes != block.expansion * planes:
            downsample  = nn.Sequential(
                nn.Conv2d(self.inplanes, block.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(block.expansion * planes)
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)


    def _upsample_add(self, x, y):
        """Bilinearly resize ``x`` to ``y``'s height/width and add ``y``.

        ``F.upsample`` is deprecated. ``F.interpolate`` with
        ``align_corners=False`` computes the same result (the old call's
        implicit default) without the deprecation/align_corners warnings.
        """
        _,_,H,W = y.size()
        return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=False) + y

    def forward(self, x):
        """Return the pyramid ``(p2, p3, p4, p5)``, each with 256 channels.

        P4, P3 and P2 pass through a 3x3 smoothing conv. P5 is the raw output
        of ``toplayer`` and is not smoothed.
        """
        # Bottom-up
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        c1 = self.maxpool(x)

        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)
        # Top-down: upsample the coarser map and add the lateral projection.
        p5 = self.toplayer(c5)
        p4 = self._upsample_add(p5, self.latlayer1(c4))
        p3 = self._upsample_add(p4, self.latlayer2(c3))
        p2 = self._upsample_add(p3, self.latlayer3(c2))
        # Smooth (reduces upsampling aliasing in the merged maps)
        p4 = self.smooth1(p4)
        p3 = self.smooth2(p3)
        p2 = self.smooth3(p2)
        return p2, p3, p4, p5


FCOS

在這裏插入圖片描述

import torch
import torch.nn as nn
import torchvision
import torchvision.models as model
from torchsummary import summary
def Conv3x3ReLU(in_channels, out_channels):
    """Return ``nn.Sequential(conv3x3, ReLU6)``.

    The 3x3 convolution uses stride 1 and padding 1, so height and width are
    preserved; only the channel count changes to ``out_channels``. Despite the
    function's name, the activation is an in-place ReLU6 (output clamped to
    [0, 6]), not a plain ReLU.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
    activation = nn.ReLU6(inplace=True)
    return nn.Sequential(conv, activation)
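# Regression head: four 3x3 conv + ReLU6 layers followed by a 3x3 output conv.
# (The FCOS paper shares one head across FPN levels; the FCOS class below
# instead builds a separate head instance for each level.)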

def locLayer(in_channels, out_channels):
    """Build the box-regression tower.

    Four ``Conv3x3ReLU`` layers that keep ``in_channels`` channels, followed
    by a plain 3x3 conv (stride 1, padding 1, no activation) that projects to
    ``out_channels``. Spatial size is preserved throughout.
    """
    tower = [Conv3x3ReLU(in_channels=in_channels, out_channels=in_channels) for _ in range(4)]
    projection = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                           kernel_size=3, stride=1, padding=1)
    return nn.Sequential(*tower, projection)
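# Classification + centerness head (same layout as the regression head); the
# FCOS class below applies it to 256-channel H x W feature maps.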
def conf_centernessLayer(in_channels, out_channels):
    """Build the classification + centerness tower.

    Same layout as the regression tower: four ``Conv3x3ReLU`` layers at
    ``in_channels`` channels, then a plain 3x3 conv (stride 1, padding 1, no
    activation) producing ``out_channels`` channels. Spatial size is preserved.
    """
    modules = []
    for _ in range(4):
        modules.append(Conv3x3ReLU(in_channels=in_channels, out_channels=in_channels))
    modules.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                             kernel_size=3, stride=1, padding=1))
    return nn.Sequential(*modules)

class FCOS(nn.Module):
    """FCOS-style detector: ResNet-50 backbone, an FPN producing P3-P7, and
    one regression head plus one classification/centerness head per level.

    Args:
        num_classes: number of classification channels predicted per location.

    ``forward`` returns ``(locs, confs, centernesses)``. Each per-level output
    is permuted to channels-last, flattened to ``(batch, -1)``, and the levels
    are concatenated along dim 1 in the order P3, P4, P5, P6, P7. Per spatial
    location there are 4 regression values, ``num_classes`` scores and one
    centerness value.
    """

    def __init__(self, num_classes=21):
        super(FCOS, self).__init__()
        self.num_classes = num_classes
        resnet = torchvision.models.resnet50()
        layers = list(resnet.children())

        # Backbone + FPN:
        #
        #           C5-C4-C3
        #     P7-P6-P5-P4-P3
        #
        # children()[:5] is the stem plus the first residual stage (C2);
        # children()[5], [6] and [7] are the next three stages (C3, C4, C5).
        self.layer1 = nn.Sequential(*layers[:5])  # c2
        self.layer2 = nn.Sequential(*layers[5])   # c3
        self.layer3 = nn.Sequential(*layers[6])   # c4
        self.layer4 = nn.Sequential(*layers[7])   # c5

        # 1x1 lateral convs bringing C5 / C4 / C3 down to 256 channels.
        self.lateral5 = nn.Conv2d(in_channels=2048, out_channels=256, kernel_size=1)  # c5
        self.lateral4 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=1)  # c4
        self.lateral3 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1)   # c3

        # Starting from P5: upsample twice (P4, P3) and downsample twice (P6, P7).
        # ConvTranspose2d output size is (h-1)*s - 2*p + k + output_padding,
        # i.e. exactly 2*h with k=4, s=2, p=1.
        self.upsample4 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=4, stride=2, padding=1)
        self.upsample3 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=4, stride=2, padding=1)

        # Stride-2 3x3 convs: P6 = downsample5(P5), P7 = downsample6(P6).
        self.downsample6 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1)
        self.downsample5 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1)

        # Prediction heads, one pair per level: a regression head with 4
        # outputs, and a classification + centerness head with num_classes + 1
        # outputs (split apart in forward). Attribute names and creation order
        # are kept stable so state_dict keys and init order do not change.
        self.loc_layer3 = locLayer(in_channels=256, out_channels=4)
        self.conf_centerness_layer3 = conf_centernessLayer(in_channels=256, out_channels=self.num_classes + 1)

        self.loc_layer4 = locLayer(in_channels=256, out_channels=4)
        self.conf_centerness_layer4 = conf_centernessLayer(in_channels=256, out_channels=self.num_classes + 1)

        self.loc_layer5 = locLayer(in_channels=256, out_channels=4)
        self.conf_centerness_layer5 = conf_centernessLayer(in_channels=256, out_channels=self.num_classes + 1)

        self.loc_layer6 = locLayer(in_channels=256, out_channels=4)
        self.conf_centerness_layer6 = conf_centernessLayer(in_channels=256, out_channels=self.num_classes + 1)

        self.loc_layer7 = locLayer(in_channels=256, out_channels=4)
        self.conf_centerness_layer7 = conf_centernessLayer(in_channels=256, out_channels=self.num_classes + 1)

        self.init_params()

    def init_params(self):
        """Re-initialize every Conv2d (Kaiming normal, fan-out) and BatchNorm2d.

        Iterates over all submodules, including the ResNet backbone.
        ConvTranspose2d is not a Conv2d subclass, so the upsample layers keep
        PyTorch's default initialization. Conv biases are left untouched.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    @staticmethod
    def _crop_like(x, ref):
        """Crop ``x`` (top-left anchored) to the height and width of ``ref``."""
        return x[:, :, :ref.size(2), :ref.size(3)]

    @staticmethod
    def _flatten(t):
        """Permute (B, C, H, W) to channels-last and flatten to (B, H*W*C)."""
        return t.permute(0, 2, 3, 1).contiguous().view(t.size(0), -1)

    def forward(self, x):
        """Run backbone, FPN and heads; return ``(locs, confs, centernesses)``.

        ``x`` is assumed to be an image batch of shape (B, 3, H, W).
        """
        c3 = self.layer2(self.layer1(x))
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)

        # Top-down pathway. Each transposed conv exactly doubles H and W, while
        # the backbone halves sizes rounding up. When the input size is not a
        # multiple of 32, the upsampled map can therefore be one pixel larger
        # than the lateral map, so crop it before adding. For multiples of 32
        # the crop is a no-op.
        p5 = self.lateral5(c5)
        lat4 = self.lateral4(c4)
        p4 = self._crop_like(self.upsample4(p5), lat4) + lat4
        lat3 = self.lateral3(c3)
        p3 = self._crop_like(self.upsample3(p4), lat3) + lat3

        # Extra coarser levels derived from P5.
        p6 = self.downsample5(p5)
        p7 = self.downsample6(p6)

        # Feed P3-P7 to their heads for classification and regression.
        levels = (
            (p3, self.loc_layer3, self.conf_centerness_layer3),
            (p4, self.loc_layer4, self.conf_centerness_layer4),
            (p5, self.loc_layer5, self.conf_centerness_layer5),
            (p6, self.loc_layer6, self.conf_centerness_layer6),
            (p7, self.loc_layer7, self.conf_centerness_layer7),
        )
        locs, confs, centernesses = [], [], []
        for feat, loc_head, conf_centerness_head in levels:
            loc = loc_head(feat)
            conf, centerness = conf_centerness_head(feat).split([self.num_classes, 1], dim=1)
            locs.append(self._flatten(loc))
            confs.append(self._flatten(conf))
            centernesses.append(self._flatten(centerness))

        # Multi-level prediction: concatenate all levels along dim 1.
        out = (torch.cat(locs, dim=1), torch.cat(confs, dim=1), torch.cat(centernesses, dim=1))
        return out

if __name__ == '__main__':
    # Build the detector and print a per-layer summary for one 3x800x1024 image.
    # torchsummary's summary() defaults to device="cuda", but the model here is
    # left on the CPU, so ask for CPU inputs to keep inputs and weights on the
    # same device. The variable is named `net` so it does not shadow the
    # `torchvision.models as model` import alias.
    net = FCOS()
    summary(net, input_size=(3, 800, 1024), device="cpu")

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章