不定期讀一篇Paper之GC-Net

不定期讀一篇Paper之GC-Net

前言

誰能經得起更多的質疑,誰才更值得相信。

即使經過很多質疑建立起來的理論,出現了新的問題,仍然可以質疑。

質疑是最基本的思考。

質疑本身也要經得起質疑。

​                — 北京大學物理學院副教授、北京大學高性能計算平臺主任

​                 — 雷奕安

​   本篇論文主要針對non local中對所有位置進行查詢,提出質疑,作者自己通過嚴格的實驗之後,發現在non local中對於所有位置查詢的到attention map,其實是一樣的,所以,作者認爲只需要計算一張注意力圖就可以了,並且,還增加了SE模塊來使網絡兼顧通道間的依賴,此外,SE模塊還可以減少少量的參數。當然,non local經過簡化、添加之後,其預測的效果並沒有丟失。實驗中,non local網絡中處於不同位置查詢的可視化效果如下圖:

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-p1k4pYwE-1586764174647)(./imgs/visiual.png)]

框架

​   論文中把GC-block拆解爲三部分,以下論各個部分的論述:

a: a context modeling module which aggregates the features of all positions together to form a global context feature; (只獲取一個位置的注意力權重)

b: a feature transform module to capture the channel-wise interdependencies; (SE本質)

c: a fusion module to merge the global context feature into features of all positions.(使用廣播機制)

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-6O4BZqDD-1586764174652)(./imgs/frame.png)]

Pytorch代碼

​   注意張量的維度,可以使用tensorboard可視化之後結合代碼理解。

class ContextBlock(nn.Module):
    def __init__(self,inplanes,ratio,pooling_type='att',
                 fusion_types=('channel_add', )):
        super(ContextBlock, self).__init__()
        valid_fusion_types = ['channel_add', 'channel_mul']

        assert pooling_type in ['avg', 'att']
        assert isinstance(fusion_types, (list, tuple))
        assert all([f in valid_fusion_types for f in fusion_types])
        assert len(fusion_types) > 0, 'at least one fusion should be used'

        self.inplanes = inplanes
        self.ratio = ratio
        self.planes = int(inplanes * ratio)
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types
		
        ## gc module
        if pooling_type == 'att':
            self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        
        ##  SE 模塊
        if 'channel_add' in fusion_types:
            self.channel_add_conv = nn.Sequential(
                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  
                nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
        else:
            self.channel_add_conv = None
            
        if 'channel_mul' in fusion_types:
            self.channel_mul_conv = nn.Sequential(
                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
        else:
            self.channel_mul_conv = None


    def spatial_pool(self, x):
        batch, channel, height, width = x.size()
        if self.pooling_type == 'att':
            ## 左路分支
            input_x = x
            # [N, C, H * W]
            input_x = input_x.view(batch, channel, height * width)
            # [N, 1, C, H * W]
            input_x = input_x.unsqueeze(1)
            
            ## 右路分支
            # [N, 1, H, W]
            context_mask = self.conv_mask(x)
            # [N, 1, H * W]
            context_mask = context_mask.view(batch, 1, height * width)
            # [N, 1, H * W]
            context_mask = self.softmax(context_mask)
            # [N, 1, H * W, 1]
            context_mask = context_mask.unsqueeze(-1)
            
            # 獲取全局attention
            # [N, 1, C, 1]
            context = torch.matmul(input_x, context_mask)
            # [N, C, 1, 1]
            context = context.view(batch, channel, 1, 1)
        else:
            # [N, C, 1, 1]
            context = self.avg_pool(x)
        return context

    def forward(self, x):
        # [N, C, 1, 1]
        context = self.spatial_pool(x)
        # [N, C, H, W]
        out = x
        if self.channel_mul_conv is not None:
            # [N, C, 1, 1]
            channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
            # broadcast機制
            out = out * channel_mul_term
        if self.channel_add_conv is not None:
            # [N, C, 1, 1]
            channel_add_term = self.channel_add_conv(context)
            out = out + channel_add_term
        return out

可視化框架:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [1, 1, 128, 128]              65
           Softmax-2              [1, 1, 16384]               0
            Conv2d-3               [1, 4, 1, 1]             260
         LayerNorm-4               [1, 4, 1, 1]               8
              ReLU-5               [1, 4, 1, 1]               0
            Conv2d-6              [1, 64, 1, 1]             320

實驗

本次實驗是在lenet網絡上進行,分別對原始lenet,gc版lenet和non local版lenet進行了對比實驗,此次實驗只要驗證自己的一些想法,併爲做參數調節,所有實驗都是在同一參數、數據和環境下進行實驗(未考慮率不同模型下的網絡優化迭代次數,統一設置epoch爲300)由於,gc和non local其模塊化的本質,所以可以很好的嵌入任何卷積網絡來提升效果,此次實驗訓練效果如下:

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-eeTLClNt-1586764174654)(./imgs/train_loss.svg)] [外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-IIWtB2Nu-1586764174655)(./imgs/train_acc.svg)]
圖1 loss 圖2 accuracy

​   從圖中可以觀察到,non local在訓練的過程中不穩定,波動較大,而gclenet收斂平穩,個人閱讀non local論文時,也注意到non local論文中的論述更傾向與視頻分類,當然,作者也用實驗證明其可以用在分類網絡中。除此之外,網絡的訓練過程中gclenet訓練時間相較於non local更短一些,這與gcnet中使用簡化的Non Local模塊和SE模塊有關,官方論證如下圖:

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-AidCP25j-1586764174661)(./imgs/param.png)]
  當然,此次實驗中仍有很多不之處,後期如果有時間會補上。😄(主要被困家中,設備欠缺)

結論

​   自己認爲這是一篇不錯的paper,可以精讀幾次,學習其中一些構思性的東西。

參考

論文地址

官方代碼

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章