Deep Learning Paper: Learning Spatial Fusion for Single-Shot Object Detection and Its PyTorch Implementation

Learning Spatial Fusion for Single-Shot Object Detection
PDF: https://arxiv.org/pdf/1911.09516.pdf
PyTorch code: https://github.com/shanglianlm0525/PyTorch-Networks

1 Overview

This paper proposes a new, data-driven pyramid feature fusion strategy called Adaptively Spatial Feature Fusion (ASFF). By learning to spatially filter out conflicting information, it suppresses the inconsistency between feature scales during gradient back-propagation, which improves the scale invariance of the features and in turn raises object detection performance.

2 Adaptively Spatial Feature Fusion (ASFF)


2-1 Feature Resizing

For levels that require upsampling, e.g. to obtain ASFF-3, the level-1 feature map must be resized to the level-3 feature map's size: a 1 x 1 convolution first adjusts its channel count to match level 3, and interpolation then brings the spatial resolution up to match. For levels that require downsampling, e.g. to obtain ASFF-1, resizing the level-2 feature map to the level-1 size only needs a 3 x 3 convolution with stride 2, while resizing the level-3 feature map to the level-1 size additionally needs a max-pooling layer with stride 2 on top of the 3 x 3 convolution.
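
A minimal sketch of these two resize paths, using nearest-neighbour interpolation for the upsampling and illustrative channel/spatial sizes (not tied to the full implementation further below):

import torch
import torch.nn as nn
import torch.nn.functional as F

c1, c3 = 512, 128  # illustrative channel widths for level 1 and level 3

# Upsampling path (level 1 -> level 3): 1 x 1 conv to match channels,
# then interpolate the spatial size up by a factor of 4.
compress = nn.Conv2d(c1, c3, kernel_size=1)
x1 = torch.randn(1, c1, 16, 16)
x1_to_3 = F.interpolate(compress(x1), scale_factor=4, mode='nearest')

# Downsampling path (level 3 -> level 1): stride-2 max pooling followed by
# a 3 x 3 stride-2 convolution, for a total downsampling factor of 4.
down = nn.Sequential(
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
    nn.Conv2d(c3, c1, kernel_size=3, stride=2, padding=1),
)
x3 = torch.randn(1, c3, 64, 64)
x3_to_1 = down(x3)

print(x1_to_3.shape)  # torch.Size([1, 128, 64, 64])
print(x3_to_1.shape)  # torch.Size([1, 512, 16, 16])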

2-2 Adaptive Fusion


The weight parameters $\alpha$, $\beta$, and $\gamma$ are obtained by passing the resized level-1 to level-3 feature maps through 1 x 1 convolutions. They are then normalized with a softmax so that each lies in [0, 1] and the three sum to 1, as written out below.
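
In the paper's notation, where $x^{n\to l}_{ij}$ is the feature from level $n$ resized to level $l$ at position $(i, j)$ and $\lambda^l_{\alpha}$, $\lambda^l_{\beta}$, $\lambda^l_{\gamma}$ are the corresponding 1 x 1 convolution outputs, the fused feature $y^l_{ij}$ at level $l$ is:

$$
y^l_{ij} = \alpha^l_{ij}\cdot x^{1\to l}_{ij} + \beta^l_{ij}\cdot x^{2\to l}_{ij} + \gamma^l_{ij}\cdot x^{3\to l}_{ij},
\qquad
\alpha^l_{ij} = \frac{e^{\lambda^l_{\alpha,ij}}}{e^{\lambda^l_{\alpha,ij}} + e^{\lambda^l_{\beta,ij}} + e^{\lambda^l_{\gamma,ij}}},
\qquad
\alpha^l_{ij} + \beta^l_{ij} + \gamma^l_{ij} = 1.
$$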

3 Experimental Comparison

3-1 Comparison with concat and element-wise sum

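For reference, the two baseline fusion operators that ASFF is compared against reduce to the following, assuming the three feature maps have already been resized to a common shape (an illustrative sketch, not the paper's training code):

import torch

# Three feature maps already resized to a common shape (N, C, H, W)
f1, f2, f3 = (torch.randn(1, 128, 64, 64) for _ in range(3))

concat_fused = torch.cat((f1, f2, f3), dim=1)  # channel concatenation: (1, 384, 64, 64)
sum_fused = f1 + f2 + f3                       # element-wise sum:      (1, 128, 64, 64)
# ASFF instead predicts per-position weights and takes a weighted sum (see the code below).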

3-2 Adding other object detection enhancement strategies

The enhancement strategies evaluated in the paper include the bag-of-freebies training tricks [43], guided anchoring [38], and an additional IoU regression loss [41]:
[43] Zhi Zhang, Tong He, Hang Zhang, Zhongyue Zhang, Junyuan Xie, and Mu Li. Bag of freebies for training object detection neural networks. arXiv preprint arXiv:1902.04103, 2019.
[38] Jiaqi Wang, Kai Chen, Shuo Yang, Chen Change Loy, and Dahua Lin. Region proposal by guided anchoring. In CVPR, 2019.
[41] Jiahui Yu, Yuning Jiang, Zhangyang Wang, Zhimin Cao, and Thomas Huang. UnitBox: An advanced object detection network. In ACM MM, 2016.

4 ASFF Visualization

(The paper visualizes the learned fusion weight maps at each pyramid level together with the corresponding detection results.)
PyTorch code:

import torch
import torch.nn as nn

# 1 x 1 convolution + BatchNorm + ReLU6
def Conv1x1BnRelu(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU6(inplace=True),
    )

# One pyramid level up: 1 x 1 conv to adjust channels, then 2x nearest-neighbour upsampling
def upSampling1(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU6(inplace=True),
        nn.Upsample(scale_factor=2, mode='nearest')
    )

# Two pyramid levels up: channel adjustment followed by a total 4x upsampling
def upSampling2(in_channels, out_channels):
    return nn.Sequential(
        upSampling1(in_channels, out_channels),
        nn.Upsample(scale_factor=2, mode='nearest'),
    )

# One pyramid level down: 3 x 3 convolution with stride 2 (2x downsampling)
def downSampling1(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU6(inplace=True),
    )

# Two pyramid levels down: stride-2 max pooling followed by the stride-2 convolution (4x downsampling)
def downSampling2(in_channels, out_channels):
    return nn.Sequential(
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        downSampling1(in_channels=in_channels, out_channels=out_channels),
    )

class ASFF(nn.Module):
    def __init__(self, level, channel1, channel2, channel3, out_channel):
        super(ASFF, self).__init__()
        self.level = level
        # Number of channels of the intermediate weight embeddings
        fused_channel = 8

        if self.level == 1:
            # level 1: downsample level-2 and level-3 features to the level-1 (coarsest) resolution
            self.level2_1 = downSampling1(channel2, channel1)
            self.level3_1 = downSampling2(channel3, channel1)

            self.weight1 = Conv1x1BnRelu(channel1, fused_channel)
            self.weight2 = Conv1x1BnRelu(channel1, fused_channel)
            self.weight3 = Conv1x1BnRelu(channel1, fused_channel)

            self.expand_conv = Conv1x1BnRelu(channel1, out_channel)

        if self.level == 2:
            # level 2: upsample level-1 and downsample level-3 features to the level-2 resolution
            self.level1_2 = upSampling1(channel1, channel2)
            self.level3_2 = downSampling1(channel3, channel2)

            self.weight1 = Conv1x1BnRelu(channel2, fused_channel)
            self.weight2 = Conv1x1BnRelu(channel2, fused_channel)
            self.weight3 = Conv1x1BnRelu(channel2, fused_channel)

            self.expand_conv = Conv1x1BnRelu(channel2, out_channel)

        if self.level == 3:
            # level 3: upsample level-1 and level-2 features to the level-3 (finest) resolution
            self.level1_3 = upSampling2(channel1, channel3)
            self.level2_3 = upSampling1(channel2, channel3)

            self.weight1 = Conv1x1BnRelu(channel3, fused_channel)
            self.weight2 = Conv1x1BnRelu(channel3, fused_channel)
            self.weight3 = Conv1x1BnRelu(channel3, fused_channel)

            self.expand_conv = Conv1x1BnRelu(channel3, out_channel)

        # 1 x 1 conv mapping the concatenated weight embeddings to three fusion weight maps
        self.weight_level = nn.Conv2d(fused_channel * 3, 3, kernel_size=1, stride=1, padding=0)

        self.softmax = nn.Softmax(dim=1)


    def forward(self, x, y, z):
        # Resize the other two levels to the resolution of the current level
        if self.level == 1:
            level_x = x
            level_y = self.level2_1(y)
            level_z = self.level3_1(z)

        if self.level == 2:
            level_x = self.level1_2(x)
            level_y = y
            level_z = self.level3_2(z)

        if self.level == 3:
            level_x = self.level1_3(x)
            level_y = self.level2_3(y)
            level_z = z

        # Per-level weight embeddings, concatenated and mapped to three spatial weight maps
        weight1 = self.weight1(level_x)
        weight2 = self.weight2(level_y)
        weight3 = self.weight3(level_z)

        level_weight = torch.cat((weight1, weight2, weight3), 1)
        weight_level = self.weight_level(level_weight)
        # Softmax over the three weight channels so they sum to 1 at every position
        weight_level = self.softmax(weight_level)

        # Weighted sum of the resized feature maps (keep the channel dim so broadcasting works for any batch size)
        fused_level = level_x * weight_level[:, 0:1, :, :] + \
                      level_y * weight_level[:, 1:2, :, :] + \
                      level_z * weight_level[:, 2:3, :, :]
        out = self.expand_conv(fused_level)
        return out

if __name__ == '__main__':
    model = ASFF(level=3, channel1=512, channel2=256, channel3=128, out_channel=128)
    print(model)

    x = torch.randn(1, 512, 16, 16)
    y = torch.randn(1, 256, 32, 32)
    z = torch.randn(1, 128, 64, 64)
    out = model(x,y,z)
    print(out.shape)
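
In a full detector, one ASFF block is attached to each pyramid level. A minimal usage sketch reusing the toy tensors above (the out_channel choices here are illustrative, not taken from the paper):

# Assumes the ASFF class and the toy tensors x, y, z defined above are in scope.
asff1 = ASFF(level=1, channel1=512, channel2=256, channel3=128, out_channel=512)
asff2 = ASFF(level=2, channel1=512, channel2=256, channel3=128, out_channel=256)
asff3 = ASFF(level=3, channel1=512, channel2=256, channel3=128, out_channel=128)

print(asff1(x, y, z).shape)  # torch.Size([1, 512, 16, 16])
print(asff2(x, y, z).shape)  # torch.Size([1, 256, 32, 32])
print(asff3(x, y, z).shape)  # torch.Size([1, 128, 64, 64])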