加载COCO预训练DeeplabV3
DeeplabV3 ResNet101
PyTorch(torchvision)可以直接加载在COCO上预训练过的DeeplabV3模型,用于语义分割任务。模型在COCO train2017的一个子集上进行预训练,该子集只包含Pascal VOC中出现的20个类别;加上背景类共21类,对应下文的默认参数num_classes=21。
调用
对于ResNet101为backbone的DeeplabV3,可以直接使用如下API调用:
torchvision.models.segmentation.deeplabv3_resnet101
(pretrained=False, progress=True, num_classes=21, aux_loss=None, **kwargs)
torchvision.models.segmentation源码
API部分的源码,定义网络。源码见pytorch官网
接口的定义函数deeplabv3_resnet101
主要参数为:网络的结构(fcn或deeplabv3),主干网络(resnet50或resnet101)。
def deeplabv3_resnet101(pretrained=False, progress=True,
                        num_classes=21, aux_loss=None, **kwargs):
    """Build a DeepLabV3 segmentation model that uses a ResNet-101 backbone.

    This is a thin convenience wrapper: it fixes the architecture to
    ``'deeplabv3'`` and the backbone to ``'resnet101'`` and delegates all the
    work to ``_load_model``.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): passed through unchanged to ``_load_model``
            (defaults to 21).
        aux_loss (bool, optional): passed through unchanged to ``_load_model``.
        **kwargs: extra keyword arguments forwarded to ``_load_model``.

    Returns:
        Whatever ``_load_model`` returns for this architecture/backbone pair.
    """
    arch_type, backbone = 'deeplabv3', 'resnet101'
    return _load_model(arch_type, backbone, pretrained, progress,
                       num_classes, aux_loss, **kwargs)
加载模型的函数_load_model
def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs):
    """Create a segmentation model and optionally load its COCO weights.

    Args:
        arch_type (str): segmentation architecture name, used both to build
            the network and to form the key looked up in ``model_urls``.
        backbone (str): backbone name, used the same way as ``arch_type``.
        pretrained (bool): when truthy, a state dict is downloaded and loaded
            into the freshly built model.
        progress (bool): forwarded to ``load_state_dict_from_url``.
        num_classes (int): forwarded to ``_segm_resnet``.
        aux_loss: whether to build the auxiliary classifier; ignored (treated
            as True) when ``pretrained`` is truthy.
        **kwargs: forwarded to ``_segm_resnet``.

    Returns:
        The constructed model, with weights loaded if ``pretrained``.

    Raises:
        NotImplementedError: if ``pretrained`` is requested but the entry for
            ``<arch_type>_<backbone>_coco`` in ``model_urls`` is ``None``.
    """
    # A pretrained model is always built with the auxiliary head
    # (presumably because the released checkpoint contains its weights --
    # TODO confirm against the published state dict).
    use_aux = True if pretrained else aux_loss
    model = _segm_resnet(arch_type, backbone, num_classes, use_aux, **kwargs)
    if not pretrained:
        return model

    arch = '_'.join((arch_type, backbone)) + '_coco'
    model_url = model_urls[arch]
    if model_url is None:
        raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))

    weights = load_state_dict_from_url(model_url, progress=progress)
    model.load_state_dict(weights)
    return model
创建用于分割的resnet函数_segm_resnet
def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True):
    """Assemble a ResNet-based segmentation network.

    Args:
        name (str): which head/model pair to use; must be a key of the local
            ``model_map`` (``'deeplabv3'`` or ``'fcn'``).
        backbone_name (str): name of a ResNet factory looked up in
            ``resnet.__dict__``.
        num_classes (int): number of output channels of the classifier heads.
        aux: when truthy, also expose ``layer3`` features as ``'aux'`` and
            attach an ``FCNHead`` auxiliary classifier.
        pretrained_backbone (bool): forwarded to the ResNet factory as
            ``pretrained``.

    Returns:
        The segmentation model built from backbone, classifier and
        (optional) auxiliary classifier.
    """
    resnet_factory = resnet.__dict__[backbone_name]
    # Dilation flags are given for layer2, layer3 and layer4 in that order.
    trunk = resnet_factory(
        pretrained=pretrained_backbone,
        replace_stride_with_dilation=[False, True, True])

    # Map backbone layer names to the keys the segmentation model reads.
    if aux:
        return_layers = {'layer4': 'out', 'layer3': 'aux'}
    else:
        return_layers = {'layer4': 'out'}
    backbone = IntermediateLayerGetter(trunk, return_layers=return_layers)

    # 1024 is the channel count fed to the head attached to 'layer3'
    # (assumes a Bottleneck-style ResNet -- TODO confirm for other backbones).
    aux_classifier = FCNHead(1024, num_classes) if aux else None

    model_map = {
        'deeplabv3': (DeepLabHead, DeepLabV3),
        'fcn': (FCNHead, FCN),
    }
    head_cls, model_cls = model_map[name]
    # 2048 is the channel count fed to the head attached to 'layer4'
    # (same Bottleneck assumption as above).
    classifier = head_cls(2048, num_classes)

    return model_cls(backbone, classifier, aux_classifier)
函数内部分别定义backbone和classifier,此外还提供了Inception和PSPNet中提到的辅助分割的接口aux。aux_classifier是从ResNet的layer3中提取特征参与计算最终的loss。
举例:定义主干网络为ResNet-101的DeeplabV3
# Worked example: build DeeplabV3 with a ResNet-101 backbone by hand,
# following the same steps as _segm_resnet.
name = 'deeplabv3'
backbone_name = 'resnet101'
num_classes = 21

# Build the ResNet-101 backbone. The factory looked up in resnet.__dict__
# has to be *called* to obtain a module; layer3/layer4 use dilation instead
# of stride, as in _segm_resnet.
backbone = resnet.__dict__[backbone_name](
    pretrained=True,
    replace_stride_with_dilation=[False, True, True])

# Expose layer4 as the main feature map and layer3 as the auxiliary one.
return_layers = {'layer4': 'out', 'layer3': 'aux'}
backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

# Auxiliary classifier: an FCNHead on the layer3 features (1024 channels).
aux_classifier = FCNHead(1024, num_classes)

model_map = {
    'deeplabv3': (DeepLabHead, DeepLabV3),
    'fcn': (FCNHead, FCN),
}

# Pick the entry for 'deeplabv3': DeepLabHead becomes the classifier and
# DeepLabV3 the model class. The classifier consumes layer4 (2048 channels).
inplanes = 2048
classifier = model_map[name][0](inplanes, num_classes)
base_model = model_map[name][1]

# Final model: ResNet-101 backbone, DeepLabHead classifier, FCNHead
# auxiliary classifier.
model = base_model(backbone, classifier, aux_classifier)
torchvision.models源码
上一级的源码,定义上边出现的各种具体的网络结构。源码见github
torchvision.models.resnet.py
常见的ResNet网络结构的定义。以下仅摘录其骨架(函数体和类体已省略),完整源码见GitHub上pytorch/vision仓库中的torchvision/models/resnet.py。
import torch
import torch.nn as nn
from .utils import load_state_dict_from_url
# Public names of this module: the ResNet class plus one factory function
# per supported architecture.
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
'wide_resnet50_2', 'wide_resnet101_2']
# Download URL of the weight file (.pth) for each architecture, keyed by the
# same names as the factory functions listed in __all__.
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
class BasicBlock(nn.Module):
class Bottleneck(nn.Module):
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None):
'''关于网络结构定义'''
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
def forward(self, x):
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
model = ResNet(block, layers, **kwargs)
def resnet18(pretrained=False, progress=True, **kwargs):
def resnet34(pretrained=False, progress=True, **kwargs):
def resnet50(pretrained=False, progress=True, **kwargs):
def resnet101(pretrained=False, progress=True, **kwargs):
def resnet152(pretrained=False, progress=True, **kwargs):
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
torchvision.models.segmentation.deeplabv3.py
DeeplabV3定义。
class DeepLabV3(_SimpleSegmentationModel):
    """DeepLabV3 semantic segmentation model.

    Introduced in `"Rethinking Atrous Convolution for Semantic Image Segmentation"
    <https://arxiv.org/abs/1706.05587>`_. The class adds nothing of its own;
    construction and the forward pass are inherited unchanged from
    ``_SimpleSegmentationModel``.

    Arguments:
        backbone (nn.Module): feature extractor. It should return an
            OrderedDict[Tensor] whose "out" entry is the last feature map,
            plus an "aux" entry when an auxiliary classifier is used.
        classifier (nn.Module): takes the backbone's "out" feature map and
            produces a dense prediction.
        aux_classifier (nn.Module, optional): auxiliary classifier used
            during training.
    """
_SimpleSegmentationModel的定义
class _SimpleSegmentationModel(nn.Module):
    """Backbone + classifier head (+ optional auxiliary head) segmentation net.

    The backbone must return a mapping of tensors containing the key "out",
    and additionally "aux" when an auxiliary classifier is attached.
    """

    __constants__ = ['aux_classifier']

    def __init__(self, backbone, classifier, aux_classifier=None):
        super().__init__()
        self.backbone = backbone
        self.classifier = classifier
        self.aux_classifier = aux_classifier

    def forward(self, x):
        """Run the model and return an OrderedDict of predictions.

        The result always has an "out" entry and, if an auxiliary classifier
        is attached, an "aux" entry. Every prediction is bilinearly resized
        to the last two dimensions of the input ``x``.
        """
        target_size = x.shape[-2:]
        # contract: the backbone returns a dict of tensors
        feature_maps = self.backbone(x)

        outputs = OrderedDict()

        # Main branch: classifier head, then upsample to the input resolution.
        main_logits = self.classifier(feature_maps["out"])
        outputs["out"] = F.interpolate(
            main_logits, size=target_size, mode='bilinear', align_corners=False)

        if self.aux_classifier is None:
            return outputs

        # Auxiliary branch: same treatment applied to the "aux" feature map.
        aux_logits = self.aux_classifier(feature_maps["aux"])
        outputs["aux"] = F.interpolate(
            aux_logits, size=target_size, mode='bilinear', align_corners=False)
        return outputs
DeepLabHead定义。
class DeepLabHead(nn.Sequential):
    """Classifier head for DeepLabV3: ASPP, a 3x3 conv block, then a 1x1 conv.

    Args:
        in_channels (int): number of channels of the incoming feature map,
            handed to ``ASPP``.
        num_classes (int): number of output channels of the final 1x1
            convolution.
    """

    def __init__(self, in_channels, num_classes):
        # Width of the intermediate feature map; the conv that follows ASPP
        # takes 256 input channels, so ASPP is expected to emit that many.
        hidden_channels = 256
        layers = [
            ASPP(in_channels, [12, 24, 36]),
            nn.Conv2d(hidden_channels, hidden_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, num_classes, 1),
        ]
        super().__init__(*layers)
deeplabv3.py还定义了ASPP、ASPPConv、ASPPPooling等具体的网络结构。
注意
最后需要注意的是:使用PyTorch官方提供的COCO预训练模型时,如果是把它的网络参数加载到自己写的网络中,那么两者参数字典(state_dict)的键名必须一致,否则load_state_dict会因键名不匹配而报错。有两种解决方法:
- 修改官方.pth文件的键名和自己写的网络一致。
- 修改自己的代码,键名仿照官方的代码起名,参考前边的resnet.py部分。