Pytorch加載COCO預訓練DeepLabV3

DeeplabV3 ResNet101

Pytorch可以直接加載用COCO預訓練過的DeeplabV3模型,用於分割問題。模型在COCO train2017的一個子集上進行預訓練,訓練集包含20個Pascal VOC中的類別。

調用

對於ResNet101爲backbone的DeeplabV3,可以直接使用如下API調用:

torchvision.models.segmentation.deeplabv3_resnet101
(pretrained=False, progress=True, num_classes=21, aux_loss=None, **kwargs)

torchvision.models.segmentation源碼

API部分的源碼,定義網絡。源碼見pytorch官網

接口的定義函數deeplabv3_resnet101

該函數本身只是一層封裝:它把網絡的結構(fcn或deeplabv3)和主幹網絡(resnet50或resnet101)作爲主要參數,連同其餘參數一起傳給_load_model。

def deeplabv3_resnet101(pretrained=False, progress=True,
                        num_classes=21, aux_loss=None, **kwargs):
    """Build a DeepLabV3 segmentation network with a ResNet-101 backbone.

    Thin convenience wrapper: it pins the architecture to ``'deeplabv3'``
    and the backbone to ``'resnet101'`` and hands every other argument,
    unchanged, to ``_load_model``.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO
            train2017 which contains the same classes as Pascal VOC.
        progress (bool): If True, displays a progress bar of the download
            to stderr.
        num_classes (int): Forwarded to ``_load_model``.
        aux_loss: Forwarded to ``_load_model``.
        **kwargs: Extra keyword arguments forwarded to ``_load_model``.

    Returns:
        The object returned by ``_load_model`` for this configuration.
    """
    arch_type, backbone = 'deeplabv3', 'resnet101'
    model = _load_model(arch_type, backbone, pretrained, progress,
                        num_classes, aux_loss, **kwargs)
    return model

加載模型的函數_load_model

def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs):
    """Build a segmentation model and optionally load its COCO checkpoint.

    Args:
        arch_type (str): Architecture name; combined with ``backbone`` to
            form the ``model_urls`` key ``'<arch_type>_<backbone>_coco'``.
        backbone (str): Backbone name forwarded to ``_segm_resnet``.
        pretrained (bool): If truthy, download the checkpoint registered in
            ``model_urls`` and load it into the model.
        progress (bool): Forwarded to ``load_state_dict_from_url``.
        num_classes (int): Forwarded to ``_segm_resnet``.
        aux_loss: Whether to build the auxiliary classifier. Forced to True
            when ``pretrained`` is set.
        **kwargs: Extra keyword arguments forwarded to ``_segm_resnet``.

    Returns:
        The model built by ``_segm_resnet``, with the checkpoint loaded when
        ``pretrained`` is set.

    Raises:
        NotImplementedError: If ``pretrained`` is set but the entry in
            ``model_urls`` for this configuration is None.
    """
    model_url = None
    if pretrained:
        # Resolve and validate the checkpoint URL *before* building anything,
        # so an unsupported configuration fails fast instead of first
        # constructing the network (and downloading backbone weights).
        arch = arch_type + '_' + backbone + '_coco'
        model_url = model_urls[arch]
        if model_url is None:
            raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
        # Presumably the published checkpoints contain auxiliary-classifier
        # weights, so the model must be built with the aux head for the
        # state dict keys to match -- TODO confirm against the checkpoint.
        aux_loss = True
        # load_state_dict() below replaces the backbone parameters anyway,
        # so fetching separate pretrained backbone weights would be wasted work.
        kwargs['pretrained_backbone'] = False

    model = _segm_resnet(arch_type, backbone, num_classes, aux_loss, **kwargs)

    if pretrained:
        state_dict = load_state_dict_from_url(model_url, progress=progress)
        model.load_state_dict(state_dict)
    return model

創建用於分割的resnet函數_segm_resnet

def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True):
    """Assemble a segmentation model (backbone + head) on a dilated ResNet.

    Args:
        name (str): Architecture key, ``'deeplabv3'`` or ``'fcn'``; selects
            the (head class, model class) pair.
        backbone_name (str): Name of a factory looked up in ``resnet.__dict__``.
        num_classes (int): Number of output channels of the classifier heads.
        aux: If truthy, also expose ``layer3`` features as ``'aux'`` and
            attach an ``FCNHead`` auxiliary classifier on them.
        pretrained_backbone (bool): Passed as ``pretrained`` to the backbone
            factory.

    Returns:
        An instance of the selected model class wrapping the backbone, the
        main classifier and the (possibly None) auxiliary classifier.

    Raises:
        KeyError: If ``backbone_name`` or ``name`` is not a known key.
    """
    # Build the ResNet with dilation replacing stride in the last two stages.
    make_resnet = resnet.__dict__[backbone_name]
    trunk = make_resnet(
        pretrained=pretrained_backbone,
        replace_stride_with_dilation=[False, True, True])

    # layer4 always feeds the main head; layer3 is tapped only for the aux head.
    if aux:
        tapped_layers = {'layer4': 'out', 'layer3': 'aux'}
    else:
        tapped_layers = {'layer4': 'out'}
    backbone = IntermediateLayerGetter(trunk, return_layers=tapped_layers)

    # 1024 is presumably the channel width of layer3 for the bottleneck
    # ResNets used here -- TODO confirm. The aux head is created before the
    # main head, matching the original construction order.
    aux_classifier = FCNHead(1024, num_classes) if aux else None

    head_cls, model_cls = {
        'deeplabv3': (DeepLabHead, DeepLabV3),
        'fcn': (FCNHead, FCN),
    }[name]

    # 2048 is presumably the channel width of layer4 -- TODO confirm.
    classifier = head_cls(2048, num_classes)

    return model_cls(backbone, classifier, aux_classifier)

函數內部分別定義backbone和classifier,此外還提供了Inception和PSPNet中提到的輔助分割的接口aux。aux_classifier是從ResNet的layer3中提取特徵參與計算最終的loss。

舉例:定義主幹網絡爲ResNet-101的DeeplabV3

name = 'deeplabv3'
backbone_name = 'resnet101'
num_classes = 21  # same default as deeplabv3_resnet101

# Look up the resnet101 factory by name and *call* it to obtain the backbone
# module, replacing stride with dilation in the last two stages exactly as
# _segm_resnet does.
backbone = resnet.__dict__[backbone_name](
    pretrained=True,
    replace_stride_with_dilation=[False, True, True])
# Expose layer4 as 'out' (main head input) and layer3 as 'aux' (aux head input).
backbone = IntermediateLayerGetter(
    backbone, return_layers={'layer4': 'out', 'layer3': 'aux'})

# Auxiliary classifier: an FCNHead on the layer3 features (1024 input channels,
# as in _segm_resnet).
aux_classifier = FCNHead(1024, num_classes)

model_map = {
    'deeplabv3': (DeepLabHead, DeepLabV3),
    'fcn': (FCNHead, FCN),
}
# From the mapping, pick DeepLabHead as the classifier and DeepLabV3 as the
# base model; the main head consumes the layer4 features (2048 input channels,
# as in _segm_resnet).
inplanes = 2048
classifier = model_map[name][0](inplanes, num_classes)
base_model = model_map[name][1]

# DeepLabV3 base model: backbone is resnet101, classifier is DeepLabHead,
# aux_classifier is FCNHead.
model = base_model(backbone, classifier, aux_classifier)

torchvision.models源碼

上一級的源碼,定義上邊出現的各種具體的網絡結構。源碼見github

torchvision.models.resnet.py

常見的ResNet網絡結構的定義(以下爲節選,大部分函數體已省略),完整源碼見torchvision的GitHub倉庫。

import torch
import torch.nn as nn
from .utils import load_state_dict_from_url

# Names exported by ``from <this module> import *``: the ResNet class plus one
# factory function per supported variant.
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
           'wide_resnet50_2', 'wide_resnet101_2']

# Checkpoint download URL for each variant, keyed by the variant name
# (the keys match the factory names listed in __all__).
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
class BasicBlock(nn.Module):

class Bottleneck(nn.Module):

# NOTE: abridged excerpt -- parts of __init__ and the bodies of _make_layer and
# forward are omitted here, so this class is not runnable as shown.
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        '''Define the layers of the network (abridged excerpt).'''
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four stages; layers[i] is passed as the `blocks` argument of stage i+1.
        self.layer1 = self._make_layer(block, 64, layers[0])
        # replace_stride_with_dilation[0..2] are forwarded as the `dilate` flag
        # of layer2..layer4. The None default is presumably normalized in lines
        # omitted from this excerpt -- TODO confirm against the full source.
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        # Pool to 1x1, then a linear layer with num_classes outputs.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    # Body omitted in this excerpt.
    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):

    # Body omitted in this excerpt.
    def forward(self, x):


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)

def resnet18(pretrained=False, progress=True, **kwargs):

def resnet34(pretrained=False, progress=True, **kwargs):

def resnet50(pretrained=False, progress=True, **kwargs):

def resnet101(pretrained=False, progress=True, **kwargs):

def resnet152(pretrained=False, progress=True, **kwargs):

def resnext50_32x4d(pretrained=False, progress=True, **kwargs):

def resnext101_32x8d(pretrained=False, progress=True, **kwargs):

def wide_resnet50_2(pretrained=False, progress=True, **kwargs):

def wide_resnet101_2(pretrained=False, progress=True, **kwargs):

torchvision.models.segmentation.deeplabv3.py

DeeplabV3定義。

class DeepLabV3(_SimpleSegmentationModel):
    """DeepLabV3 semantic segmentation model.

    Reference: `"Rethinking Atrous Convolution for Semantic Image Segmentation"
    <https://arxiv.org/abs/1706.05587>`_.

    The class body adds nothing of its own; construction and the forward
    pass are inherited from ``_SimpleSegmentationModel``.

    Arguments:
        backbone (nn.Module): network that computes the features for the
            model. It should return an OrderedDict[Tensor] whose key "out"
            holds the last feature map used, plus key "aux" when an
            auxiliary classifier is used.
        classifier (nn.Module): module that takes the "out" element returned
            by the backbone and returns a dense prediction.
        aux_classifier (nn.Module, optional): auxiliary classifier used
            during training.
    """

_SimpleSegmentationModel的定義

class _SimpleSegmentationModel(nn.Module):
    __constants__ = ['aux_classifier']

    def __init__(self, backbone, classifier, aux_classifier=None):
        super(_SimpleSegmentationModel, self).__init__()
        self.backbone = backbone
        self.classifier = classifier
        self.aux_classifier = aux_classifier

    def forward(self, x):
        input_shape = x.shape[-2:]
        # contract: features is a dict of tensors
        features = self.backbone(x)

        result = OrderedDict()
        x = features["out"] # 輸出特徵
        x = self.classifier(x) # 頭部分類器
        x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
        result["out"] = x # 上採樣恢復分辨率

        if self.aux_classifier is not None:
            x = features["aux"]
            x = self.aux_classifier(x)
            x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
            result["aux"] = x

        return result

DeepLabHead定義。

class DeepLabHead(nn.Sequential):
    """Sequential DeepLabV3 head: ASPP, 3x3 conv + BN + ReLU, 1x1 classifier.

    Arguments:
        in_channels (int): channel count passed to ``ASPP``.
        num_classes (int): output channels of the final 1x1 convolution.
    """

    def __init__(self, in_channels, num_classes):
        # Sub-modules are created in the same order they run in the head.
        # 256 is presumably the output width of ASPP -- TODO confirm.
        stages = [
            ASPP(in_channels, [12, 24, 36]),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, num_classes, kernel_size=1),
        ]
        super(DeepLabHead, self).__init__(*stages)

deeplabv3.py還定義了ASPP、ASPPConv、ASPPPooling等具體的網絡結構。

注意

最後需要注意的是,使用pytorch官方提供的COCO預訓練模型,如果是使用他的網絡參數+自己寫的網絡,要注意加載網絡的時候,兩者的參數字典的鍵名要一致。兩種解決方法:

  1. 修改官方.pth文件的鍵名和自己寫的網絡一致。
  2. 修改自己的代碼,鍵名仿照官方的代碼起名,參考前邊的resnet.py部分。
發佈了25 篇原創文章 · 獲贊 18 · 訪問量 4萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章