[Dissecting the Ox] Implementing RetinaNet from Scratch (3): FPN, heads, Anchors, overall network structure, switching the backbone

All of the code has been uploaded to my GitHub repository: https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training
If you find it useful, please give it a star!
All of the code below has been tested under PyTorch 1.4 and verified to be correct.

FPN

The idea of FPN originates from this paper: https://arxiv.org/pdf/1612.03144.pdf . RetinaNet uses a 5-level FPN: the three feature maps output by ResNet backbone stages 2/3/4 are fused to produce the P3, P4, and P5 FPN feature maps, and the stage 4 output is further downsampled to produce the P6 and P7 FPN feature maps.
The whole FPN fusion process looks like this:

C5-----3x3 conv, stride 2-----P6 out-----ReLU + 3x3 conv, stride 2-----P7 out
|
|1x1 conv, reduce channels to 256
|
P5-----upsample--------------------+
|                                  |
|3x3 conv                          |
|                                  |
P5 out                             |
                                   |
C4                                 |
|                                  |
|1x1 conv, reduce channels to 256  |
|                                  |
P4-----element-wise add------------+
|
P4-----upsample--------------------+
|                                  |
|3x3 conv                          |
|                                  |
P4 out                             |
                                   |
C3                                 |
|                                  |
|1x1 conv, reduce channels to 256  |
|                                  |
P3-----element-wise add------------+
|
|3x3 conv
|
P3 out

In the end we obtain five outputs: P3 out, P4 out, P5 out, P6 out, and P7 out.
The FPN code is as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F


class FPN(nn.Module):
    def __init__(self, C3_inplanes, C4_inplanes, C5_inplanes, planes):
        super(FPN, self).__init__()
        self.P3_1 = nn.Conv2d(C3_inplanes,
                              planes,
                              kernel_size=1,
                              stride=1,
                              padding=0)
        self.P3_2 = nn.Conv2d(planes,
                              planes,
                              kernel_size=3,
                              stride=1,
                              padding=1)
        self.P4_1 = nn.Conv2d(C4_inplanes,
                              planes,
                              kernel_size=1,
                              stride=1,
                              padding=0)
        self.P4_2 = nn.Conv2d(planes,
                              planes,
                              kernel_size=3,
                              stride=1,
                              padding=1)
        self.P5_1 = nn.Conv2d(C5_inplanes,
                              planes,
                              kernel_size=1,
                              stride=1,
                              padding=0)
        self.P5_2 = nn.Conv2d(planes,
                              planes,
                              kernel_size=3,
                              stride=1,
                              padding=1)
        self.P6 = nn.Conv2d(C5_inplanes,
                            planes,
                            kernel_size=3,
                            stride=2,
                            padding=1)

        self.P7 = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(planes, planes, kernel_size=3, stride=2, padding=1))

    def forward(self, inputs):
        [C3, C4, C5] = inputs

        P5 = self.P5_1(C5)
        P4 = self.P4_1(C4)
        P4 = F.interpolate(P5, size=(P4.shape[2], P4.shape[3]),
                           mode='nearest') + P4
        P3 = self.P3_1(C3)
        P3 = F.interpolate(P4, size=(P3.shape[2], P3.shape[3]),
                           mode='nearest') + P3

        P6 = self.P6(C5)
        P7 = self.P7(P6)

        P5 = self.P5_2(P5)
        P4 = self.P4_2(P4)
        P3 = self.P3_2(P3)

        del C3, C4, C5

        return [P3, P4, P5, P6, P7]

Note that upsampling must be done with F.interpolate(feature, size=(h, w), mode='nearest'). Each downsampling step in the backbone does not necessarily shrink the feature map to exactly half its size (a side length may end up one pixel larger or smaller), and only this API guarantees that, for an arbitrary input size, the upsampled feature map of the current level exactly matches the size of the higher-resolution feature map it is added to.
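A quick way to verify this is to feed the FPN fake C3/C4/C5 tensors for an input size that is not a multiple of 32, e.g. 600x800 (a minimal sketch; the channel counts 512/1024/2048 assume a ResNet50 backbone):

import torch

# C3/C4/C5 sizes produced by a ResNet50 backbone for a 600x800 input
C3 = torch.randn(1, 512, 75, 100)
C4 = torch.randn(1, 1024, 38, 50)  # note: 38 * 2 = 76 != 75
C5 = torch.randn(1, 2048, 19, 25)
fpn = FPN(512, 1024, 2048, 256)
for P in fpn([C3, C4, C5]):
    print(P.shape)
# [1,256,75,100], [1,256,38,50], [1,256,19,25], [1,256,10,13], [1,256,5,7]

Because the interpolation target is size=(P3.shape[2], P3.shape[3]) rather than scale_factor=2, the upsampled P4 (38x50 -> 75x100) lines up with P3 exactly.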

heads

In object detection, the samples that directly participate in the loss computation mean something different from those in a classification task. In classification, each image is one sample and has one label; in an anchor-based detector such as RetinaNet, each Anchor is one sample and each Anchor has its own label.
RetinaNet has two heads: a classification head and a regression head. The first part of both heads is the same: four 3x3 conv + ReLU layers. The classification head then ends with a 3x3 conv whose output channel count is num_anchors x num_classes, while the regression head ends with a 3x3 conv whose output channel count is num_anchors x 4. Here num_classes is the number of object classes in the detection dataset, and num_anchors is the number of Anchors placed at each feature-map location (9 in RetinaNet); the same two heads are shared across all FPN levels.
The heads code is as follows:

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class ClsHead(nn.Module):
    def __init__(self,
                 inplanes,
                 num_anchors,
                 num_classes,
                 num_layers=4,
                 prior=0.01):
        super(ClsHead, self).__init__()
        layers = []
        for _ in range(num_layers):
            layers.append(
                nn.Conv2d(inplanes,
                          inplanes,
                          kernel_size=3,
                          stride=1,
                          padding=1))
            layers.append(nn.ReLU(inplace=True))
        layers.append(
            nn.Conv2d(inplanes,
                      num_anchors * num_classes,
                      kernel_size=3,
                      stride=1,
                      padding=1))
        self.cls_head = nn.Sequential(*layers)
        self.sigmoid = nn.Sigmoid()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, val=0)

        # fill the bias of the last conv layer so that every anchor's initial
        # predicted foreground probability is about `prior` (0.01)
        b = -math.log((1 - prior) / prior)
        self.cls_head[-1].bias.data.fill_(b)

    def forward(self, x):
        x = self.cls_head(x)
        x = self.sigmoid(x)

        return x


class RegHead(nn.Module):
    def __init__(self, inplanes, num_anchors, num_layers=4):
        super(RegHead, self).__init__()
        layers = []
        for _ in range(num_layers):
            layers.append(
                nn.Conv2d(inplanes,
                          inplanes,
                          kernel_size=3,
                          stride=1,
                          padding=1))
            layers.append(nn.ReLU(inplace=True))
        layers.append(
            nn.Conv2d(inplanes,
                      num_anchors * 4,
                      kernel_size=3,
                      stride=1,
                      padding=1))

        self.reg_head = nn.Sequential(*layers)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, val=0)

    def forward(self, x):
        x = self.reg_head(x)

        return x

Note that the initialization of the classification head differs slightly from that of the regression head: the bias of the final conv layer in the classification head is filled with -log((1 - prior) / prior) with prior = 0.01, so that every Anchor starts with a predicted foreground probability of about 0.01. This initialization follows exactly what is described in Section 4.1 of the RetinaNet paper.
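As a quick sanity check of this prior initialization (a small sketch of the math only):

import math

prior = 0.01
b = -math.log((1 - prior) / prior)
print(b)                       # about -4.595
print(1 / (1 + math.exp(-b)))  # sigmoid(b) = 0.01, i.e. ~1% initial foreground probability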

Anchor

Anchors were first proposed in Faster R-CNN (https://arxiv.org/pdf/1506.01497.pdf). The so-called Anchors are a set of prior boxes with different aspect ratios and sizes. In RetinaNet, the backbone plus FPN give us 5 levels of FPN feature maps. If the input image size is h, w = 640, 640, the h, w of the P3, P4, P5, P6, P7 feature maps are [80, 80], [40, 40], [20, 20], [10, 10], [5, 5] respectively. Each point on a feature map maps back to a grid cell on the original image, so we get 8525 grid cells in total. RetinaNet places 9 prior boxes with different aspect ratios and sizes at the center of every grid cell, so the total number of Anchors is 76725. Note that, within one FPN level, the 9 prior-box shapes are identical for every grid cell; only the box centers differ from cell to cell.
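The grid and Anchor counts above can be verified in a few lines (a sketch assuming a 640x640 input):

# P3..P7 feature map sizes for a 640x640 input (strides 8/16/32/64/128)
feature_sizes = [(80, 80), (40, 40), (20, 20), (10, 10), (5, 5)]
num_grids = sum(h * w for h, w in feature_sizes)
print(num_grids)      # 8525 grid cells
print(num_grids * 9)  # 76725 Anchors in total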
RetinaNet first generates 9 width/height combinations from three aspect ratios and three scale factors:

import math

ratios = [0.5, 1, 2]
scales = [2**0, 2**(1.0 / 3.0), 2**(2.0 / 3.0)]
aspects = [[[s * math.sqrt(r), s * math.sqrt(1 / r)] for s in scales]
           for r in ratios]
# aspects
# [[[0.7071067811865476, 1.4142135623730951], [0.8908987181403394, 1.7817974362806788], [1.122462048309373, 2.244924096618746]], [[1.0, 1.0], [1.2599210498948732, 1.2599210498948732], [1.5874010519681994, 1.5874010519681994]], [[1.4142135623730951, 0.7071067811865476], [1.7817974362806788, 0.8908987181403394], [2.244924096618746, 1.122462048309373]]]

Then, for the 5 FPN levels, 5 base sizes are multiplied by the width/height combinations above. The base sizes are:

[32, 32], [64, 64], [128, 128], [256, 256], [512, 512]

This gives us the center point and width/height of every Anchor. Finally, we convert the Anchor coordinates into the form [x_min, y_min, x_max, y_max], i.e. the top-left and bottom-right corners of the box.
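For example, on P3 (base size 32) the 9 Anchor widths and heights work out as follows; converting from center form to corner form is then just a shift by half the width and height (a sketch that reuses the aspects list above):

base_size = 32
for w_ratio, h_ratio in [pair for row in aspects for pair in row]:
    w, h = base_size * w_ratio, base_size * h_ratio
    # an anchor centered at (cx, cy) becomes [cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2]
    print(round(w, 1), round(h, 1))
# 22.6 45.3 / 28.5 57.0 / 35.9 71.8 / 32.0 32.0 / 40.3 40.3 / 50.8 50.8 / 45.3 22.6 / 57.0 28.5 / 71.8 35.9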

The Anchor code is as follows:

import math
import numpy as np
import torch
import torch.nn as nn


class RetinaAnchors(nn.Module):
    def __init__(self, areas, ratios, scales, strides):
        super(RetinaAnchors, self).__init__()
        self.areas = areas
        self.ratios = ratios
        self.scales = scales
        self.strides = strides

    def forward(self, batch_size, fpn_feature_sizes):
        """
        generate batch anchors
        """
        device = fpn_feature_sizes.device
        one_sample_anchors = []
        for index, area in enumerate(self.areas):
            base_anchors = self.generate_base_anchors(area, self.scales,
                                                      self.ratios)
            featrue_anchors = self.generate_anchors_on_feature_map(
                base_anchors, fpn_feature_sizes[index], self.strides[index])
            featrue_anchors = featrue_anchors.to(device)
            one_sample_anchors.append(featrue_anchors)

        batch_anchors = []
        for per_level_featrue_anchors in one_sample_anchors:
            per_level_featrue_anchors = per_level_featrue_anchors.unsqueeze(
                0).repeat(batch_size, 1, 1)
            batch_anchors.append(per_level_featrue_anchors)

        # if input size:[B,3,640,640]
        # batch_anchors shape:[[B, 57600, 4],[B, 14400, 4],[B, 3600, 4],[B, 900, 4],[B, 225, 4]]
        # per anchor format:[x_min,y_min,x_max,y_max]
        return batch_anchors

    def generate_base_anchors(self, area, scales, ratios):
        """
        generate base anchor
        """
        # get w,h aspect ratio,shape:[9,2]
        aspects = torch.tensor([[[s * math.sqrt(r), s * math.sqrt(1 / r)]
                                 for s in scales]
                                for r in ratios]).view(-1, 2)
        # base anchor for each position on feature map,shape[9,4]
        base_anchors = torch.zeros((len(scales) * len(ratios), 4))

        # compute aspect w\h,shape[9,2]
        base_w_h = area * aspects
        base_anchors[:, 2:] += base_w_h

        # base_anchors format: [x_min,y_min,x_max,y_max],center point:[0,0],shape[9,4]
        base_anchors[:, 0] -= base_anchors[:, 2] / 2
        base_anchors[:, 1] -= base_anchors[:, 3] / 2
        base_anchors[:, 2] /= 2
        base_anchors[:, 3] /= 2

        return base_anchors

    def generate_anchors_on_feature_map(self, base_anchors, feature_map_size,
                                        stride):
        """
        generate all anchors on a feature map
        """
        # shifts_x shape:[w],shifts_y shape:[h]
        shifts_x = (torch.arange(0, feature_map_size[0]) + 0.5) * stride
        shifts_y = (torch.arange(0, feature_map_size[1]) + 0.5) * stride

        # shifts shape:[w,h,2] -> [w,h,4] -> [w,h,1,4]
        shifts = torch.tensor([[[shift_x, shift_y] for shift_y in shifts_y]
                               for shift_x in shifts_x]).repeat(1, 1,
                                                                2).unsqueeze(2)

        # base anchors shape:[9,4] -> [1,1,9,4]
        base_anchors = base_anchors.unsqueeze(0).unsqueeze(0)
        # generate all featrue map anchors on each feature map points
        # featrue map anchors shape:[w,h,9,4] -> [h,w,9,4] -> [h*w*9,4]
        feature_map_anchors = (base_anchors + shifts).permute(
            1, 0, 2, 3).contiguous().view(-1, 4)

        # feature_map_anchors format: [anchor_nums,4],4:[x_min,y_min,x_max,y_max]
        return feature_map_anchors
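A minimal usage sketch for RetinaAnchors (the areas/ratios/scales/strides below are the same values used by the RetinaNet class in the next section):

if __name__ == '__main__':
    areas = torch.tensor([[32, 32], [64, 64], [128, 128], [256, 256],
                          [512, 512]])
    ratios = torch.tensor([0.5, 1, 2])
    scales = torch.tensor([2**0, 2**(1.0 / 3.0), 2**(2.0 / 3.0)])
    strides = torch.tensor([8, 16, 32, 64, 128], dtype=torch.float)
    anchor_generator = RetinaAnchors(areas, ratios, scales, strides)
    # [W, H] of P3..P7 for a 640x640 input
    fpn_feature_sizes = torch.tensor([[80, 80], [40, 40], [20, 20], [10, 10],
                                      [5, 5]])
    batch_anchors = anchor_generator(2, fpn_feature_sizes)
    for per_level_anchors in batch_anchors:
        print(per_level_anchors.shape)
    # [2, 57600, 4], [2, 14400, 4], [2, 3600, 4], [2, 900, 4], [2, 225, 4]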

Overall network structure

RetinaNet uses a ResNet network as its backbone. Here we use the official PyTorch (torchvision) ResNet implementation. Since we need the feature maps output by stages 2/3/4 (layer2/layer3/layer4 in the torchvision implementation), we wrap the network as follows:

import torch
import torch.nn as nn
import torchvision.models as models


class ResNetBackbone(nn.Module):
    def __init__(self, resnet_type="resnet50"):
        super(ResNetBackbone, self).__init__()
        self.model = models.__dict__[resnet_type](**{"pretrained": True})
        del self.model.fc
        del self.model.avgpool

    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)

        x = self.model.layer1(x)
        C3 = self.model.layer2(x)
        C4 = self.model.layer3(C3)
        C5 = self.model.layer4(C4)

        del x

        return [C3, C4, C5]
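A quick shape check for this backbone wrapper (a sketch; 512/1024/2048 are the C3/C4/C5 channel counts of resnet50, and the pretrained weights are downloaded from torchvision on first use):

if __name__ == '__main__':
    backbone = ResNetBackbone(resnet_type='resnet50')
    C3, C4, C5 = backbone(torch.randn(1, 3, 640, 640))
    print(C3.shape, C4.shape, C5.shape)
    # [1, 512, 80, 80], [1, 1024, 40, 40], [1, 2048, 20, 20]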

Then we can assemble the complete RetinaNet network:

import os
import sys
import numpy as np

BASE_DIR = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.dirname(
        os.path.abspath(__file__)))))
sys.path.append(BASE_DIR)

from public.path import pretrained_models_path

from public.detection.models.backbone import ResNetBackbone
from public.detection.models.fpn import FPN
from public.detection.models.head import ClsHead, RegHead
from public.detection.models.anchor import RetinaAnchors

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = [
    'resnet18_retinanet',
    'resnet34_retinanet',
    'resnet50_retinanet',
    'resnet101_retinanet',
    'resnet152_retinanet',
]

model_urls = {
    'resnet18_retinanet':
    'empty',
    'resnet34_retinanet':
    'empty',
    'resnet50_retinanet':
    'empty',
    'resnet101_retinanet':
    'empty',
    'resnet152_retinanet':
    'empty',
}


# assert input annotations are[x_min,y_min,x_max,y_max]
class RetinaNet(nn.Module):
    def __init__(self, resnet_type, num_anchors=9, num_classes=80, planes=256):
        super(RetinaNet, self).__init__()
        self.backbone = ResNetBackbone(resnet_type=resnet_type)
        expand_ratio = {
            "resnet18": 1,
            "resnet34": 1,
            "resnet50": 4,
            "resnet101": 4,
            "resnet152": 4
        }
        C3_inplanes, C4_inplanes, C5_inplanes = int(
            128 * expand_ratio[resnet_type]), int(
                256 * expand_ratio[resnet_type]), int(
                    512 * expand_ratio[resnet_type])
        self.fpn = FPN(C3_inplanes, C4_inplanes, C5_inplanes, planes)

        self.num_anchors = num_anchors
        self.num_classes = num_classes
        self.planes = planes

        self.cls_head = ClsHead(self.planes,
                                self.num_anchors,
                                self.num_classes,
                                num_layers=4,
                                prior=0.01)

        self.reg_head = RegHead(self.planes, self.num_anchors, num_layers=4)

        self.areas = torch.tensor([[32, 32], [64, 64], [128, 128], [256, 256],
                                   [512, 512]])
        self.ratios = torch.tensor([0.5, 1, 2])
        self.scales = torch.tensor([2**0, 2**(1.0 / 3.0), 2**(2.0 / 3.0)])
        self.strides = torch.tensor([8, 16, 32, 64, 128], dtype=torch.float)

        self.anchors = RetinaAnchors(self.areas, self.ratios, self.scales,
                                     self.strides)

    def forward(self, inputs):
        self.batch_size, _, _, _ = inputs.shape
        device = inputs.device

        [C3, C4, C5] = self.backbone(inputs)

        del inputs

        features = self.fpn([C3, C4, C5])

        del C3, C4, C5

        self.fpn_feature_sizes = []
        cls_heads, reg_heads = [], []
        for feature in features:
            self.fpn_feature_sizes.append([feature.shape[3], feature.shape[2]])
            cls_head = self.cls_head(feature)
            # [N,9*num_classes,H,W] -> [N,H*W*9,num_classes]
            cls_head = cls_head.permute(0, 2, 3, 1).contiguous().view(
                self.batch_size, -1, self.num_classes)
            cls_heads.append(cls_head)

            reg_head = self.reg_head(feature)
            # [N, 9*4,H,W] -> [N,H*W*9, 4]
            reg_head = reg_head.permute(0, 2, 3, 1).contiguous().view(
                self.batch_size, -1, 4)
            reg_heads.append(reg_head)

        del features

        self.fpn_feature_sizes = torch.tensor(
            self.fpn_feature_sizes).to(device)

        # if input size:[B,3,640,640]
        # features shape:[[B, 256, 80, 80],[B, 256, 40, 40],[B, 256, 20, 20],[B, 256, 10, 10],[B, 256, 5, 5]]
        # cls_heads shape:[[B, 57600, 80],[B, 14400, 80],[B, 3600, 80],[B, 900, 80],[B, 225, 80]]
        # reg_heads shape:[[B, 57600, 4],[B, 14400, 4],[B, 3600, 4],[B, 900, 4],[B, 225, 4]]
        # batch_anchors shape:[[B, 57600, 4],[B, 14400, 4],[B, 3600, 4],[B, 900, 4],[B, 225, 4]]

        batch_anchors = self.anchors(self.batch_size, self.fpn_feature_sizes)

        return cls_heads, reg_heads, batch_anchors


def _retinanet(arch, pretrained, progress, **kwargs):
    model = RetinaNet(arch, **kwargs)
    # only load state_dict()
    if pretrained:
        model.load_state_dict(
            torch.load(model_urls[arch + "_retinanet"],
                       map_location=torch.device('cpu')))

    return model


def resnet18_retinanet(pretrained=False, progress=True, **kwargs):
    return _retinanet('resnet18', pretrained, progress, **kwargs)


def resnet34_retinanet(pretrained=False, progress=True, **kwargs):
    return _retinanet('resnet34', pretrained, progress, **kwargs)


def resnet50_retinanet(pretrained=False, progress=True, **kwargs):
    return _retinanet('resnet50', pretrained, progress, **kwargs)


def resnet101_retinanet(pretrained=False, progress=True, **kwargs):
    return _retinanet('resnet101', pretrained, progress, **kwargs)


def resnet152_retinanet(pretrained=False, progress=True, **kwargs):
    return _retinanet('resnet152', pretrained, progress, **kwargs)

With the code above we have implemented a RetinaNet family that uses the ResNet family as its backbone.
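A short end-to-end sketch of the resulting network (it assumes the file above is importable as-is; building the model downloads the ImageNet-pretrained ResNet weights from torchvision):

if __name__ == '__main__':
    net = resnet50_retinanet(num_classes=80)
    images = torch.randn(2, 3, 640, 640)
    cls_heads, reg_heads, batch_anchors = net(images)
    for cls_head, reg_head, per_level_anchors in zip(cls_heads, reg_heads,
                                                     batch_anchors):
        print(cls_head.shape, reg_head.shape, per_level_anchors.shape)
    # P3 level: [2, 57600, 80] [2, 57600, 4] [2, 57600, 4]
    # ... down to the P7 level: [2, 225, 80] [2, 225, 4] [2, 225, 4]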

Switching the backbone

Switching the backbone of the RetinaNet implemented above is very simple: replace ResNetBackbone with your newly defined backbone, then set C3_inplanes/C4_inplanes/C5_inplanes to the channel counts of the C3/C4/C5 feature maps output by the new backbone (see the sketch at the end of this section).
Here we define backbones for the DarkNet and VovNet networks implemented in part 4 of my base model series (please read that article if you want to know the structure of these two networks):

# note: here `models` is the model zoo defined in the base model series
# (registering darknet19/darknet53/VoVNet), not torchvision.models
class Darknet19Backbone(nn.Module):
    def __init__(self):
        super(Darknet19Backbone, self).__init__()
        self.model = models.__dict__['darknet19'](**{"pretrained": True})
        del self.model.avgpool
        del self.model.layer7

    def forward(self, x):
        x = self.model.layer1(x)
        x = self.model.maxpool1(x)
        x = self.model.layer2(x)
        C3 = self.model.layer3(x)
        C4 = self.model.layer4(C3)
        C5 = self.model.layer5(C4)
        C5 = self.model.layer6(C5)

        del x

        return [C3, C4, C5]


class Darknet53Backbone(nn.Module):
    def __init__(self):
        super(Darknet53Backbone, self).__init__()
        self.model = models.__dict__['darknet53'](**{"pretrained": True})
        del self.model.fc
        del self.model.avgpool

    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.conv2(x)
        x = self.model.block1(x)
        x = self.model.conv3(x)
        x = self.model.block2(x)
        x = self.model.conv4(x)
        C3 = self.model.block3(x)
        C4 = self.model.conv5(C3)
        C4 = self.model.block4(C4)
        C5 = self.model.conv6(C4)
        C5 = self.model.block5(C5)

        del x

        return [C3, C4, C5]


class VovNetBackbone(nn.Module):
    def __init__(self, vovnet_type='VoVNet39_se'):
        super(VovNetBackbone, self).__init__()
        self.model = models.__dict__[vovnet_type](**{"pretrained": True})
        del self.model.fc
        del self.model.avgpool

    def forward(self, x):
        x = self.model.stem(x)

        features = []
        for stage in self.model.stages:
            x = stage(x)
            features.append(x)

        del x

        return features[1:]

Because self.fpn_feature_sizes is taken directly from the H and W of each FPN level's feature map, the number of Anchors will never mismatch the shapes of the prediction heads.
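For example, one way to wire Darknet53Backbone into the RetinaNet class above is to subclass it and swap the backbone and FPN (a sketch; it assumes Darknet53Backbone is importable here and that its C3/C4/C5 channel counts are 256/512/1024, as in the base model article):

class Darknet53RetinaNet(RetinaNet):
    def __init__(self, num_anchors=9, num_classes=80, planes=256):
        # reuse the heads, anchors and forward() from RetinaNet, then replace
        # the backbone and rebuild the FPN with darknet53's C3/C4/C5 channels
        super(Darknet53RetinaNet, self).__init__('resnet18',
                                                 num_anchors=num_anchors,
                                                 num_classes=num_classes,
                                                 planes=planes)
        self.backbone = Darknet53Backbone()
        self.fpn = FPN(256, 512, 1024, planes)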
