[Dissecting the Ox] Implementing RetinaNet from Scratch (3): FPN, heads, Anchor, overall network structure, switching the backbone

All code has been uploaded to my GitHub repository: https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training
If you find it useful, please give it a star!
All of the code below was tested with PyTorch 1.4 and ran correctly.

FPN

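The idea of FPN originates from this paper: https://arxiv.org/pdf/1612.03144.pdf . RetinaNet uses a 5-level FPN. P3, P4 and P5 are obtained by fusing the three feature maps output by stages 2/3/4 of the ResNet backbone (called C3, C4 and C5 below), and P6 and P7 are obtained by further downsampling the stage 4 output (C5).
The whole FPN fusion process is as follows: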

C5-----3x3 conv downsample-----P6 out-----relu+3x3 conv downsample-----P7 out
|
|1x1 conv reduce channel to 256
|
P5-----upsample---------------------
|                                  |
|3x3 conv                          |
|                                  |
P5 out                             |
                                   |
C4                                 |
|                                  |
|1x1 conv reduce channel to 256    |
|                                  |
P4-----element-wise add-------------                
|
P4-----upsample---------------------
|                                  |
|3x3 conv                          |
|                                  |
P4 out                             |
                                   |
C3                                 |
|                                  |
|1x1 conv reduce channel to 256    |
|                                  |
P3-----element-wise add-------------     
|
|3x3 conv
|
P3 out

This finally yields five outputs: P3 out, P4 out, P5 out, P6 out and P7 out.
The FPN code is as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F


class FPN(nn.Module):
    def __init__(self, C3_inplanes, C4_inplanes, C5_inplanes, planes):
        super(FPN, self).__init__()
        self.P3_1 = nn.Conv2d(C3_inplanes,
                              planes,
                              kernel_size=1,
                              stride=1,
                              padding=0)
        self.P3_2 = nn.Conv2d(planes,
                              planes,
                              kernel_size=3,
                              stride=1,
                              padding=1)
        self.P4_1 = nn.Conv2d(C4_inplanes,
                              planes,
                              kernel_size=1,
                              stride=1,
                              padding=0)
        self.P4_2 = nn.Conv2d(planes,
                              planes,
                              kernel_size=3,
                              stride=1,
                              padding=1)
        self.P5_1 = nn.Conv2d(C5_inplanes,
                              planes,
                              kernel_size=1,
                              stride=1,
                              padding=0)
        self.P5_2 = nn.Conv2d(planes,
                              planes,
                              kernel_size=3,
                              stride=1,
                              padding=1)
        self.P6 = nn.Conv2d(C5_inplanes,
                            planes,
                            kernel_size=3,
                            stride=2,
                            padding=1)

        self.P7 = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(planes, planes, kernel_size=3, stride=2, padding=1))

    def forward(self, inputs):
        [C3, C4, C5] = inputs

        P5 = self.P5_1(C5)
        P4 = self.P4_1(C4)
        P4 = F.interpolate(P5, size=(P4.shape[2], P4.shape[3]),
                           mode='nearest') + P4
        P3 = self.P3_1(C3)
        P3 = F.interpolate(P4, size=(P3.shape[2], P3.shape[3]),
                           mode='nearest') + P3

        P6 = self.P6(C5)
        P7 = self.P7(P6)

        P5 = self.P5_2(P5)
        P4 = self.P4_2(P4)
        P3 = self.P3_2(P3)

        del C3, C4, C5

        return [P3, P4, P5, P6, P7]

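Note that the upsampling must use F.interpolate(feature, size=(h, w), mode='nearest'), because each time the backbone downsamples, the feature map is not necessarily reduced to exactly half of its previous size (the side length may be off by one). Only by passing the target size explicitly can we guarantee that, for an input of any size, the upsampled feature map of the current level matches the size of the higher-resolution feature map it is added to.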
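To see where the off-by-one comes from, here is a small sketch in plain Python that only tracks side lengths. It assumes the usual ResNet layout (a 7x7 stride-2 conv1 with padding 3, a 3x3 stride-2 maxpool with padding 1, and one stride-2 downsampling step in each of layer2/3/4); the helper names are made up for this illustration.

```python
def conv_out_size(n, kernel, stride, padding):
    """Output side length of a conv/pool layer for an input side length n."""
    return (n + 2 * padding - kernel) // stride + 1


def resnet_c3_c4_c5_sizes(n):
    """Side lengths of C3, C4, C5 for an input side length n (assumed layout)."""
    n = conv_out_size(n, kernel=7, stride=2, padding=3)  # conv1
    n = conv_out_size(n, kernel=3, stride=2, padding=1)  # maxpool
    # layer1 keeps the size; layer2/3/4 each halve it (rounding up)
    c3 = conv_out_size(n, kernel=3, stride=2, padding=1)
    c4 = conv_out_size(c3, kernel=3, stride=2, padding=1)
    c5 = conv_out_size(c4, kernel=3, stride=2, padding=1)
    return c3, c4, c5


c3, c4, c5 = resnet_c3_c4_c5_sizes(600)
print(c3, c4, c5)        # 75 38 19
print(2 * c4, "vs", c3)  # 76 vs 75
```

With a 600 x 600 input, doubling C4 would give 76 while C3 is 75, so a fixed 2x upsample could not be added to C3. Passing size=(h, w) avoids this.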

heads

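In object detection, the samples that directly take part in the loss computation mean something different from those in classification. In a classification task, we treat each image as one sample, and each image has one label; in an anchor-based detector such as RetinaNet, each anchor is one sample, and each anchor has one label.
RetinaNet has two heads: classification and regression. The first part of both heads is four 3x3 conv + relu layers. The classification head then ends with a 3x3 conv whose number of output channels is num_anchors x num_classes; the regression head ends with a 3x3 conv whose number of output channels is num_anchors x 4. Here num_classes is the number of object categories in the detection dataset, and num_anchors is the number of anchors placed at each position of an FPN feature map (9 in RetinaNet). The same two heads are applied to every FPN level.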
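As a quick sanity check of these channel counts, the sketch below computes the head output channels and the shape of the predictions after they are flattened to one row per anchor. The function name is hypothetical; the defaults (9 anchors, 80 classes) are the ones used by the RetinaNet class later in this article.

```python
def head_output_shapes(h, w, num_anchors=9, num_classes=80):
    """Channel counts of the two heads and the per-image shapes after flattening."""
    cls_channels = num_anchors * num_classes  # last conv of the classification head
    reg_channels = num_anchors * 4            # last conv of the regression head
    rows = h * w * num_anchors                # one row per anchor on this level
    return cls_channels, reg_channels, (rows, num_classes), (rows, 4)


# P3 of a 640 x 640 input is 80 x 80
print(head_output_shapes(80, 80))  # (720, 36, (57600, 80), (57600, 4))
```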
The heads code is as follows:

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class ClsHead(nn.Module):
    def __init__(self,
                 inplanes,
                 num_anchors,
                 num_classes,
                 num_layers=4,
                 prior=0.01):
        super(ClsHead, self).__init__()
        layers = []
        for _ in range(num_layers):
            layers.append(
                nn.Conv2d(inplanes,
                          inplanes,
                          kernel_size=3,
                          stride=1,
                          padding=1))
            layers.append(nn.ReLU(inplace=True))
        layers.append(
            nn.Conv2d(inplanes,
                      num_anchors * num_classes,
                      kernel_size=3,
                      stride=1,
                      padding=1))
        self.cls_head = nn.Sequential(*layers)
        self.sigmoid = nn.Sigmoid()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, val=0)

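        # bias of the final cls conv: b = -log((1 - prior) / prior), so that
        # the initial sigmoid output is about `prior` for every anchor
        b = -math.log((1 - prior) / prior)
        self.cls_head[-1].bias.data.fill_(b)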

    def forward(self, x):
        x = self.cls_head(x)
        x = self.sigmoid(x)

        return x


class RegHead(nn.Module):
    def __init__(self, inplanes, num_anchors, num_layers=4):
        super(RegHead, self).__init__()
        layers = []
        for _ in range(num_layers):
            layers.append(
                nn.Conv2d(inplanes,
                          inplanes,
                          kernel_size=3,
                          stride=1,
                          padding=1))
            layers.append(nn.ReLU(inplace=True))
        layers.append(
            nn.Conv2d(inplanes,
                      num_anchors * 4,
                      kernel_size=3,
                      stride=1,
                      padding=1))

        self.reg_head = nn.Sequential(*layers)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, val=0)

    def forward(self, x):
        x = self.reg_head(x)

        return x

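Note that the classification head and the regression head above are initialized slightly differently: all conv weights are drawn from a normal distribution with std=0.01 and all biases are set to 0, except that the bias of the last conv of the classification head is set to -log((1 - prior) / prior). This initialization follows the description in Section 4.1 of the RetinaNet paper.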
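The effect of that bias can be checked with a few lines of plain Python. This is only a numeric illustration of the formula used in ClsHead, with prior=0.01 as in the code above:

```python
import math

prior = 0.01
# same formula as in ClsHead: bias of the final classification conv
b = -math.log((1 - prior) / prior)
# what the sigmoid outputs when the conv response is dominated by this bias
p = 1 / (1 + math.exp(-b))

print(round(b, 4))  # -4.5951
print(round(p, 4))  # 0.01
```

So at the start of training every anchor is predicted as foreground with a probability of roughly 0.01.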

Anchor

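Anchors were first proposed in Faster R-CNN (https://arxiv.org/pdf/1506.01497.pdf). Anchors are a set of prior boxes with different aspect ratios and different sizes. In RetinaNet, the backbone and the FPN give us 5 levels of FPN feature maps. If the original image size is h, w = 640, 640, then the h, w of the P3, P4, P5, P6, P7 feature maps are [80, 80], [40, 40], [20, 20], [10, 10], [5, 5] respectively. Each point on a feature map maps back to one grid cell in the original image, so we get 8525 grid cells in total. RetinaNet places 9 prior boxes with different aspect ratios and sizes at the center of every grid cell, so the total number of anchors is 76725. Note that within one FPN level every grid cell carries the same 9 width/height combinations; only the box centers differ.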
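The two totals follow directly from the feature map sizes. A short check in plain Python, using the sizes listed above for a 640 x 640 input:

```python
# (h, w) of P3..P7 for a 640 x 640 input
feature_sizes = [(80, 80), (40, 40), (20, 20), (10, 10), (5, 5)]
anchors_per_cell = 9

cells_per_level = [h * w for h, w in feature_sizes]
anchors_per_level = [n * anchors_per_cell for n in cells_per_level]

print(cells_per_level)         # [6400, 1600, 400, 100, 25]
print(sum(cells_per_level))    # 8525
print(anchors_per_level)       # [57600, 14400, 3600, 900, 225]
print(sum(anchors_per_level))  # 76725
```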
RetinaNet first uses three aspect ratios and three scale factors to generate 9 width/height combinations:

ratios = [0.5, 1, 2]
scales = [2**0, 2**(1.0 / 3.0), 2**(2.0 / 3.0)]
aspects = [[[s * math.sqrt(r), s * math.sqrt(1 / r)] for s in scales]
           for r in ratios]
# aspects
# [[[0.7071067811865476, 1.4142135623730951], [0.8908987181403394, 1.7817974362806788], [1.122462048309373, 2.244924096618746]], [[1.0, 1.0], [1.2599210498948732, 1.2599210498948732], [1.5874010519681994, 1.5874010519681994]], [[1.4142135623730951, 0.7071067811865476], [1.7817974362806788, 0.8908987181403394], [2.244924096618746, 1.122462048309373]]]

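Then, for the 5 levels of FPN feature maps, 5 base sizes (one per level, from P3 to P7) are multiplied by the width/height combinations above. The base sizes are: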

[32, 32], [64, 64], [128, 128], [256, 256], [512, 512]

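This gives us the center point, width and height of every anchor. Finally, we convert the anchor coordinates to the form [x_min,y_min,x_max,y_max], i.e. the top-left and bottom-right corners of the box.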
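For a single grid cell the whole procedure fits in a few lines. The following is a plain-Python sketch, not the tensor implementation used in this article; the function name is made up, and it assumes the cell center is at (index + 0.5) * stride, with stride 8 and base size 32 for P3.

```python
import math

ratios = [0.5, 1, 2]
scales = [2**0, 2**(1.0 / 3.0), 2**(2.0 / 3.0)]


def anchors_at_cell(col, row, stride, base_size):
    """The 9 anchors of one grid cell as [x_min, y_min, x_max, y_max]."""
    # center of the grid cell in original image coordinates
    cx, cy = (col + 0.5) * stride, (row + 0.5) * stride
    boxes = []
    for r in ratios:
        for s in scales:
            w = base_size * s * math.sqrt(r)
            h = base_size * s * math.sqrt(1 / r)
            boxes.append([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2])
    return boxes


boxes = anchors_at_cell(col=0, row=0, stride=8, base_size=32)
print(len(boxes))  # 9
print(boxes[3])    # [-12.0, -12.0, 20.0, 20.0]
```

boxes[3] is the ratio 1, scale 1 anchor: a 32 x 32 box centered at (4, 4). Anchors near the image border can extend outside the image, as this one does.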

The Anchor code is as follows:

import math
import numpy as np
import torch
import torch.nn as nn


class RetinaAnchors(nn.Module):
    def __init__(self, areas, ratios, scales, strides):
        super(RetinaAnchors, self).__init__()
        self.areas = areas
        self.ratios = ratios
        self.scales = scales
        self.strides = strides

    def forward(self, batch_size, fpn_feature_sizes):
        """
        generate batch anchors
        """
        device = fpn_feature_sizes.device
        one_sample_anchors = []
        for index, area in enumerate(self.areas):
            base_anchors = self.generate_base_anchors(area, self.scales,
                                                      self.ratios)
            featrue_anchors = self.generate_anchors_on_feature_map(
                base_anchors, fpn_feature_sizes[index], self.strides[index])
            featrue_anchors = featrue_anchors.to(device)
            one_sample_anchors.append(featrue_anchors)

        batch_anchors = []
        for per_level_featrue_anchors in one_sample_anchors:
            per_level_featrue_anchors = per_level_featrue_anchors.unsqueeze(
                0).repeat(batch_size, 1, 1)
            batch_anchors.append(per_level_featrue_anchors)

        # if input size:[B,3,640,640]
        # batch_anchors shape:[[B, 57600, 4],[B, 14400, 4],[B, 3600, 4],[B, 900, 4],[B, 225, 4]]
        # per anchor format:[x_min,y_min,x_max,y_max]
        return batch_anchors

    def generate_base_anchors(self, area, scales, ratios):
        """
        generate base anchor
        """
        # get w,h aspect ratio,shape:[9,2]
        aspects = torch.tensor([[[s * math.sqrt(r), s * math.sqrt(1 / r)]
                                 for s in scales]
                                for r in ratios]).view(-1, 2)
        # base anchor for each position on feature map,shape[9,4]
        base_anchors = torch.zeros((len(scales) * len(ratios), 4))

        # compute aspect w\h,shape[9,2]
        base_w_h = area * aspects
        base_anchors[:, 2:] += base_w_h

        # base_anchors format: [x_min,y_min,x_max,y_max],center point:[0,0],shape[9,4]
        base_anchors[:, 0] -= base_anchors[:, 2] / 2
        base_anchors[:, 1] -= base_anchors[:, 3] / 2
        base_anchors[:, 2] /= 2
        base_anchors[:, 3] /= 2

        return base_anchors

    def generate_anchors_on_feature_map(self, base_anchors, feature_map_size,
                                        stride):
        """
        generate all anchors on a feature map
        """
        # shifts_x shape:[w],shifts_y shape:[h]
        shifts_x = (torch.arange(0, feature_map_size[0]) + 0.5) * stride
        shifts_y = (torch.arange(0, feature_map_size[1]) + 0.5) * stride

        # shifts shape:[w,h,2] -> [w,h,4] -> [w,h,1,4]
        shifts = torch.tensor([[[shift_x, shift_y] for shift_y in shifts_y]
                               for shift_x in shifts_x]).repeat(1, 1,
                                                                2).unsqueeze(2)

        # base anchors shape:[9,4] -> [1,1,9,4]
        base_anchors = base_anchors.unsqueeze(0).unsqueeze(0)
        # generate all feature map anchors on each feature map points
        # feature map anchors shape:[w,h,9,4] -> [h,w,9,4] -> [h*w*9,4]
        feature_map_anchors = (base_anchors + shifts).permute(
            1, 0, 2, 3).contiguous().view(-1, 4)

        # feature_map_anchors format: [anchor_nums,4],4:[x_min,y_min,x_max,y_max]
        return feature_map_anchors

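Overall network structure

RetinaNet uses a ResNet as its backbone. Here we use the official PyTorch implementation of the ResNet structure. Since we need the feature maps output by stages 2/3/4 at the end, the network has to be redefined slightly: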

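import torch.nn as nn
# assumption: `models` is torchvision.models, the official PyTorch ResNet implementation
import torchvision.models as models


class ResNetBackbone(nn.Module):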
    def __init__(self, resnet_type="resnet50"):
        super(ResNetBackbone, self).__init__()
        self.model = models.__dict__[resnet_type](**{"pretrained": True})
        del self.model.fc
        del self.model.avgpool

    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)

        x = self.model.layer1(x)
        C3 = self.model.layer2(x)
        C4 = self.model.layer3(C3)
        C5 = self.model.layer4(C4)

        del x

        return [C3, C4, C5]

Then we can put together the complete RetinaNet network structure:

import os
import sys
import numpy as np

BASE_DIR = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.dirname(
        os.path.abspath(__file__)))))
sys.path.append(BASE_DIR)

from public.path import pretrained_models_path

from public.detection.models.backbone import ResNetBackbone
from public.detection.models.fpn import FPN
from public.detection.models.head import ClsHead, RegHead
from public.detection.models.anchor import RetinaAnchors

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = [
    'resnet18_retinanet',
    'resnet34_retinanet',
    'resnet50_retinanet',
    'resnet101_retinanet',
    'resnet152_retinanet',
]

model_urls = {
    'resnet18_retinanet':
    'empty',
    'resnet34_retinanet':
    'empty',
    'resnet50_retinanet':
    'empty',
    'resnet101_retinanet':
    'empty',
    'resnet152_retinanet':
    'empty',
}


# input annotations are assumed to be in [x_min,y_min,x_max,y_max] format
class RetinaNet(nn.Module):
    def __init__(self, resnet_type, num_anchors=9, num_classes=80, planes=256):
        super(RetinaNet, self).__init__()
        self.backbone = ResNetBackbone(resnet_type=resnet_type)
        expand_ratio = {
            "resnet18": 1,
            "resnet34": 1,
            "resnet50": 4,
            "resnet101": 4,
            "resnet152": 4
        }
        C3_inplanes, C4_inplanes, C5_inplanes = int(
            128 * expand_ratio[resnet_type]), int(
                256 * expand_ratio[resnet_type]), int(
                    512 * expand_ratio[resnet_type])
        self.fpn = FPN(C3_inplanes, C4_inplanes, C5_inplanes, planes)

        self.num_anchors = num_anchors
        self.num_classes = num_classes
        self.planes = planes

        self.cls_head = ClsHead(self.planes,
                                self.num_anchors,
                                self.num_classes,
                                num_layers=4,
                                prior=0.01)

        self.reg_head = RegHead(self.planes, self.num_anchors, num_layers=4)

        self.areas = torch.tensor([[32, 32], [64, 64], [128, 128], [256, 256],
                                   [512, 512]])
        self.ratios = torch.tensor([0.5, 1, 2])
        self.scales = torch.tensor([2**0, 2**(1.0 / 3.0), 2**(2.0 / 3.0)])
        self.strides = torch.tensor([8, 16, 32, 64, 128], dtype=torch.float)

        self.anchors = RetinaAnchors(self.areas, self.ratios, self.scales,
                                     self.strides)

    def forward(self, inputs):
        self.batch_size, _, _, _ = inputs.shape
        device = inputs.device

        [C3, C4, C5] = self.backbone(inputs)

        del inputs

        features = self.fpn([C3, C4, C5])

        del C3, C4, C5

        self.fpn_feature_sizes = []
        cls_heads, reg_heads = [], []
        for feature in features:
            self.fpn_feature_sizes.append([feature.shape[3], feature.shape[2]])
            cls_head = self.cls_head(feature)
            # [N,9*num_classes,H,W] -> [N,H*W*9,num_classes]
            cls_head = cls_head.permute(0, 2, 3, 1).contiguous().view(
                self.batch_size, -1, self.num_classes)
            cls_heads.append(cls_head)

            reg_head = self.reg_head(feature)
            # [N, 9*4,H,W] -> [N,H*W*9, 4]
            reg_head = reg_head.permute(0, 2, 3, 1).contiguous().view(
                self.batch_size, -1, 4)
            reg_heads.append(reg_head)

        del features

        self.fpn_feature_sizes = torch.tensor(
            self.fpn_feature_sizes).to(device)

        # if input size:[B,3,640,640]
        # features shape:[[B, 256, 80, 80],[B, 256, 40, 40],[B, 256, 20, 20],[B, 256, 10, 10],[B, 256, 5, 5]]
        # cls_heads shape:[[B, 57600, 80],[B, 14400, 80],[B, 3600, 80],[B, 900, 80],[B, 225, 80]]
        # reg_heads shape:[[B, 57600, 4],[B, 14400, 4],[B, 3600, 4],[B, 900, 4],[B, 225, 4]]
        # batch_anchors shape:[[B, 57600, 4],[B, 14400, 4],[B, 3600, 4],[B, 900, 4],[B, 225, 4]]

        batch_anchors = self.anchors(self.batch_size, self.fpn_feature_sizes)

        return cls_heads, reg_heads, batch_anchors


def _retinanet(arch, pretrained, progress, **kwargs):
    model = RetinaNet(arch, **kwargs)
    # only load state_dict()
    if pretrained:
        model.load_state_dict(
            torch.load(model_urls[arch + "_retinanet"],
                       map_location=torch.device('cpu')))

    return model


def resnet18_retinanet(pretrained=False, progress=True, **kwargs):
    return _retinanet('resnet18', pretrained, progress, **kwargs)


def resnet34_retinanet(pretrained=False, progress=True, **kwargs):
    return _retinanet('resnet34', pretrained, progress, **kwargs)


def resnet50_retinanet(pretrained=False, progress=True, **kwargs):
    return _retinanet('resnet50', pretrained, progress, **kwargs)


def resnet101_retinanet(pretrained=False, progress=True, **kwargs):
    return _retinanet('resnet101', pretrained, progress, **kwargs)


def resnet152_retinanet(pretrained=False, progress=True, **kwargs):
    return _retinanet('resnet152', pretrained, progress, **kwargs)

With the code above, we have implemented a RetinaNet family that uses the ResNet family as backbones.

Switching the backbone

Switching the backbone of the RetinaNet implemented above is very simple: replace ResNetBackbone with your newly defined backbone, then set C3_inplanes/C4_inplanes/C5_inplanes to the channel counts of the C3/C4/C5 feature maps output by the new backbone.
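For comparison, these are the values the ResNet variants end up with, derived from the expand_ratio table in the RetinaNet class above. The function name is made up for this sketch; a new backbone would need its own triple, read from its actual C3/C4/C5 outputs.

```python
# same table as in RetinaNet.__init__ above
expand_ratio = {
    "resnet18": 1,
    "resnet34": 1,
    "resnet50": 4,
    "resnet101": 4,
    "resnet152": 4,
}


def resnet_fpn_inplanes(resnet_type):
    """(C3_inplanes, C4_inplanes, C5_inplanes) for a ResNet backbone."""
    e = expand_ratio[resnet_type]
    return 128 * e, 256 * e, 512 * e


print(resnet_fpn_inplanes("resnet18"))  # (128, 256, 512)
print(resnet_fpn_inplanes("resnet50"))  # (512, 1024, 2048)
```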
Below we define backbones for DarkNet and VovNet, both implemented in part four of my base model series (see that article if you want to know the structure of these two networks):

class Darknet19Backbone(nn.Module):
    def __init__(self):
        super(Darknet19Backbone, self).__init__()
        self.model = models.__dict__['darknet19'](**{"pretrained": True})
        del self.model.avgpool
        del self.model.layer7

    def forward(self, x):
        x = self.model.layer1(x)
        x = self.model.maxpool1(x)
        x = self.model.layer2(x)
        C3 = self.model.layer3(x)
        C4 = self.model.layer4(C3)
        C5 = self.model.layer5(C4)
        C5 = self.model.layer6(C5)

        del x

        return [C3, C4, C5]


class Darknet53Backbone(nn.Module):
    def __init__(self):
        super(Darknet53Backbone, self).__init__()
        self.model = models.__dict__['darknet53'](**{"pretrained": True})
        del self.model.fc
        del self.model.avgpool

    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.conv2(x)
        x = self.model.block1(x)
        x = self.model.conv3(x)
        x = self.model.block2(x)
        x = self.model.conv4(x)
        C3 = self.model.block3(x)
        C4 = self.model.conv5(C3)
        C4 = self.model.block4(C4)
        C5 = self.model.conv6(C4)
        C5 = self.model.block5(C5)

        del x

        return [C3, C4, C5]


class VovNetBackbone(nn.Module):
    def __init__(self, vovnet_type='VoVNet39_se'):
        super(VovNetBackbone, self).__init__()
        self.model = models.__dict__[vovnet_type](**{"pretrained": True})
        del self.model.fc
        del self.model.avgpool

    def forward(self, x):
        x = self.model.stem(x)

        features = []
        for stage in self.model.stages:
            x = stage(x)
            features.append(x)

        del x

        return features[1:]

Because self.fpn_feature_sizes reads the H and W of each FPN level's feature map directly, the number of anchors can never mismatch the shape of the prediction heads.
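The sketch below illustrates why this matters for inputs that are not a multiple of 128. It tracks side lengths only, in plain Python, assuming the ResNet layout (7x7 stride-2 conv1 with padding 3, 3x3 stride-2 maxpool with padding 1) and that each further level is one 3x3 stride-2 step with padding 1; the helper names are hypothetical.

```python
def conv_out_size(n, kernel, stride, padding):
    """Output side length of a conv/pool layer for an input side length n."""
    return (n + 2 * padding - kernel) // stride + 1


def fpn_sizes(n):
    """Side lengths of P3..P7 for a square input of side n (assumed layout)."""
    n = conv_out_size(n, kernel=7, stride=2, padding=3)  # conv1
    n = conv_out_size(n, kernel=3, stride=2, padding=1)  # maxpool
    sizes = []
    for _ in range(5):  # C3, C4, C5 (same sizes as P3, P4, P5), then P6, P7
        n = conv_out_size(n, kernel=3, stride=2, padding=1)
        sizes.append(n)
    return sizes


measured = fpn_sizes(600)
guessed = [600 // s for s in (8, 16, 32, 64, 128)]
print(measured)  # [75, 38, 19, 10, 5]
print(guessed)   # [75, 37, 18, 9, 4]
print(sum(s * s * 9 for s in measured))  # 67995
print(sum(s * s * 9 for s in guessed))   # 66735
```

Deriving the grid from input_size // stride would give 66735 anchors for a 600 x 600 input while the heads would predict 67995 rows; reading the sizes from the feature maps keeps the two counts equal.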
