深度学习论文:Learning Spatial Fusion for Single-Shot Object Detection及其PyTorch实现

Learning Spatial Fusion for Single-Shot Object Detection
PDF: https://arxiv.org/pdf/1911.09516.pdf
PyTorch代码: https://github.com/shanglianlm0525/PyTorch-Networks

1 概述

本文提出了一种新的数据驱动的自适应空间特征融合(ASFF)金字塔特征融合方式, 通过学习空间上的过滤冲突信息以抑制梯度反传的时候不一致的方法,从而改善了特征的比例不变性, 进而提高目标检测性能
在这里插入图片描述

2 自适应特征融合(ASFF)

在这里插入图片描述

2-1 特征尺寸调整(Feature Resizing)

对于需要上采样的层,如想得到ASFF3,需要将level1的特征图调整到和level3的特征图尺寸一致,采用的方式是先通过1 x 1卷积调整到和level3通道数一致,再用插值的方式将尺寸调整到一致。而对于需要下采样的层,比如想得到ASFF1,对于level2的特征图到level1的特征图只需要用一个3 x 3并且步长为2的卷积就OK了,如果是level3的特征图到level1的特征图则需要在3 x 3卷积的基础上再加上一个步长为2的最大池化层。

2-2 自适应融合(Adaptive Fusion)

在这里插入图片描述

权重参数α\alphaβ\betaγ\gamma,则是通过resize后的level1~level3特征图经过 1 x 1卷积获得的。并且参数α\alphaβ\betaγ\gamma,之后通过softmax使得它们的范围都在[0,1]并且和为1

3 实验对比

3-1 与concat, elewise_sum 对比

在这里插入图片描述

3-2 加入其他目标检测增强策略

在这里插入图片描述
[43] Zhi Zhang, Tong He, Hang Zhang, Zhongyuan Zhang, Junyuan Xie, and Mu Li. Bag of freebies for training object detection neural networks. arXiv preprint arXiv:1902.04103, 2019.
点击查看
[38] Jiaqi Wang, Kai Chen, Shuo Yang, Chen Change Loy, and Dahua Lin. Region proposal by guided anchoring. In CVPR, 2019.
[41] Jiahui Yu, Yuning Jiang, Zhangyang Wang, Zhimin Cao, and Thomas Huang. Unitbox: An advanced object detection network. In ACMM, 2016.

4 ASFF可视化

在这里插入图片描述
PyTorch代码:

import torch
import torch.nn as nn
import torchvision

def Conv1x1BnRelu(in_channels,out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU6(inplace=True),
    )

def upSampling1(in_channels,out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=1,stride=1,padding=0,bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU6(inplace=True),
        nn.Upsample(scale_factor=2, mode='nearest')
    )

def upSampling2(in_channels,out_channels):
    return nn.Sequential(
        upSampling1(in_channels,out_channels),
        nn.Upsample(scale_factor=2, mode='nearest'),
    )

def downSampling1(in_channels,out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU6(inplace=True),
    )

def downSampling2(in_channels,out_channels):
    return nn.Sequential(
        nn.MaxPool2d(kernel_size=3, stride=2,padding=1),
        downSampling1(in_channels=in_channels, out_channels=out_channels),
    )

class ASFF(nn.Module):
    def __init__(self, level, channel1, channel2, channel3, out_channel):
        super(ASFF, self).__init__()
        self.level = level
        funsed_channel = 8

        if self.level == 1:
            # level = 1:
            self.level2_1 = downSampling1(channel2,channel1)
            self.level3_1 = downSampling2(channel3,channel1)

            self.weight1 = Conv1x1BnRelu(channel1, funsed_channel)
            self.weight2 = Conv1x1BnRelu(channel1, funsed_channel)
            self.weight3 = Conv1x1BnRelu(channel1, funsed_channel)

            self.expand_conv = Conv1x1BnRelu(channel1,out_channel)

        if self.level == 2:
            #  level = 2:
            self.level1_2 = upSampling1(channel1,channel2)
            self.level3_2 = downSampling1(channel3,channel2)

            self.weight1 = Conv1x1BnRelu(channel2, funsed_channel)
            self.weight2 = Conv1x1BnRelu(channel2, funsed_channel)
            self.weight3 = Conv1x1BnRelu(channel2, funsed_channel)

            self.expand_conv = Conv1x1BnRelu(channel2, out_channel)

        if self.level == 3:
            #  level = 3:
            self.level1_3 = upSampling2(channel1,channel3)
            self.level2_3 = upSampling1(channel2,channel3)

            self.weight1 = Conv1x1BnRelu(channel3, funsed_channel)
            self.weight2 = Conv1x1BnRelu(channel3, funsed_channel)
            self.weight3 = Conv1x1BnRelu(channel3, funsed_channel)

            self.expand_conv = Conv1x1BnRelu(channel3, out_channel)

        self.weight_level = nn.Conv2d(funsed_channel * 3, 3, kernel_size=1, stride=1, padding=0)

        self.softmax = nn.Softmax(dim=1)


    def forward(self, x, y, z):
        if self.level == 1:
            level_x = x
            level_y = self.level2_1(y)
            level_z = self.level3_1(z)

        if self.level == 2:
            level_x = self.level1_2(x)
            level_y = y
            level_z = self.level3_2(z)

        if self.level == 3:
            level_x = self.level1_3(x)
            level_y = self.level2_3(y)
            level_z = z

        weight1 = self.weight1(level_x)
        weight2 = self.weight2(level_y)
        weight3 = self.weight3(level_z)

        level_weight = torch.cat((weight1, weight2, weight3), 1)
        weight_level = self.weight_level(level_weight)
        weight_level = self.softmax(weight_level)

        fused_level = level_x * weight_level[:,0,:,:] + level_y * weight_level[:,1,:,:] + level_z * weight_level[:,2,:,:]
        out = self.expand_conv(fused_level)
        return out

if __name__ == '__main__':
    model = ASFF(level=3, channel1=512, channel2=256, channel3=128, out_channel=128)
    print(model)

    x = torch.randn(1, 512, 16, 16)
    y = torch.randn(1, 256, 32, 32)
    z = torch.randn(1, 128, 64, 64)
    out = model(x,y,z)
    print(out.shape)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章