CV06-Xception筆記

目錄

一、爲啥是Xception

二、Xception結構

2.1 Xception結構基本描述

2.2 實現細節

2.3 DeepLabV3+改進

三、記錄pytorch採坑relu激活函數inplace=True


Xception筆記,記錄一些自己認爲重要的要點,以免日後遺忘。

復現Xception論文、DeepLabV+改進的Xception,代碼地址https://github.com/Ascetics/Pytorch-Xception

一、爲啥是Xception

Xception脫胎於Inception,Inception的思想是將卷積分成cross-channel conv和spatial conv,更準確的說是先用1x1卷積得到幾個不同channel(小於輸入channel)的結果,再在這些結果上分別用3x3、5x5 conv,也就是論文Figure 1描述的那樣。Inception的這種算法背後,本質上是將cross-channel conv和spatial conv解耦。

考慮將Inception簡化:去掉平均池化層,只用3x3 conv(2個3x3 conv相當於1個5x5 conv)。就測到了論文Figure 2描述的這種結構。

 

在Figure 2的基礎上,用1個channel很大的1x1 conv 將輸入映射到一個channel很大的輸出上。再將這個輸出“切成幾段”,“切成幾段”分別做3x3 spatial conv,就得到了論文中Figure 3的結構。作者在此提出一個問題,這樣將cross-channel conv和spatial conv完全解耦分開合理嗎?完全解耦分開,可以這樣做嗎?
基於Figure 3提出的假設,做一個極端的Inception模型。還是先用1個channel很大的1x1 conv 將輸入映射到一個channel很大的輸出上,然後“切成幾段”變成“切片”,每個channel切一片。對每個channel做3x3卷積。這樣極端的設計就接近於深度可分離卷積depthwise separable convolution。

爲什麼是“接近”,而不是“就是”呢?因爲和depthwise separable convolution的操作順序、操作內容不一樣。

  1. 順序上,depthwise separable convolution,用3x3 conv進行spatial conv,用1x1 conv進行cross-channel conv;極端版本Inception先用1x1 conv再用3x3 conv;
  2. 內容上,depthwise separable convolution,spatial conv和cross-channel conv之間沒有非線性(ReLU激活函數);極端版本Inception,卷積之間有非線性(ReLU激活函數);

作者認爲第一個區別是不重要的,特別是因爲這些操作要在堆疊(深度學習)的環境中使用。第二個區別重要,作者研究了一下,結論見論文Figure 10。本文後面會解釋。

要看懂Xception,需要了解VGG、Inception、Depthwise Separable Convlution和ResNet,都會用到。

二、Xception結構

2.1 Xception結構基本描述

卷積神經網絡特徵提取中的卷積都可以完全解耦,變成深度可分離卷積(Xception也就是Extreme Inception的意思)。接收了這一設定,Xception結構被解釋爲論文Figure 5的樣子。

Xception的特徵提取基礎由36個conv layer構成。這36個conv layer被組織成14個module,除了第一個和最後一個module,其餘的module都帶有residual connection(殘差,參看何凱明大神的ResNet)。簡言之,Xception結構就是連續使用depthwise separable convolution layer和residual connection。

2.2 實現細節

如Figure 5 描述所述。

輸入先經過Entry flow,不重複;再經過Middle flow,Middle flow重複8次;最後經過Exit flow,不重複。

所有的Conv 和 Separable Conv後面都加BN層,但是Figure 5沒有畫出來。

所有的Separable Conv都用depth=1,也就是每個depth-wise都是“切片”的。

注意, depthwise separable convolution在spatial conv和cross-channel conv之間不要加ReLU激活函數,任何激活函數都不要加。論文Figure 10展示了,這裏不加激活函數效果最好,加ReLU、ELU都不好。

 還有一些是論文中沒有明說的細節。

Residual Connection在1x1卷積後面也加上BN。Residual Connection加上以後,不要着急做激活函數,仔細看圖,激活函數ReLU是屬於下一個Block的。這就導致了代碼實現上採坑,下一節詳細記錄一下。也算長個記性。

2.3 DeepLabV3+改進

這一部分在下一篇博客,學DeepLabV3+中再記錄。

三、記錄pytorch採坑relu激活函數inplace=True

上面2.2寫了一個細節,如果嚴格按照論文的示意圖來實現Xception,那麼每個block第一個操作不是SeparableConv,而是ReLU(紅色框)。如果僅僅是第一個操作是ReLU也沒有關係,但是旁邊還有個Residual Connection(藍色框)。自古紅藍出CP,於是坑來了,在反向傳播的時候,報了個錯,明顯是因爲inplace導致了某個對象被modify了,反向傳播求梯度報錯。(此時我所有代碼的ReLU用的都是inplace=True,省內存嘛)

以Entry Flow的Block爲例,先來欣賞一下錯誤的代碼。

class _PoolEntryBlock(nn.Module):
    def __init__(self, in_channels, out_channels, relu1=True):
        """
        Entry Flow的3個下采樣module
        按論文所說,每個Conv和Separable Conv都需要跟BN
        論文Figure 5中,第1個Separable Conv前面沒有ReLU,需要判斷一下,
        論文Figure 5中,每個module的Separable Conv的out_channels一樣,MaxPool做下采樣
        :param in_channels: 輸入channels
        :param out_channels: 輸出channels
        :param relu1: 判斷有沒有第一個ReLU,默認是有的
        """
        super(_PoolEntryBlock, self).__init__()
        self.project = ResidualConnection(in_channels, out_channels, stride=2)
        self.relu1 = None
        if relu1:
            self.relu1 = nn.ReLU(inplace=True)  
        self.sepconv1 = SeparableConv2d(in_channels, out_channels,
                                        kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.relu2 = nn.ReLU(inplace=True)
        self.sepconv2 = SeparableConv2d(out_channels, out_channels,
                                        kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
        pass

    def forward(self, x):
        identity = self.project(x)  # residual connection 準備

        if self.relu1:  # 第1個Separable Conv前面沒有ReLU,需要判斷一下
            x = self.relu1(x)
        x = self.sepconv1(x)  # 第2個Separable Conv
        x = self.bn1(x)

        x = self.relu2(x)
        x = self.sepconv2(x)  # 第2個Separable Conv
        x = self.bn2(x)

        x = self.maxpool(x)  # 下采樣2倍

        x = x + identity  # residual connection 相加
        return x

    pass

採坑時的做法,先算出Residual Connection,再做relu、SeparableConv。Residual Connection時,已經進行過一次卷積操作,此時要求輸入x本身不能發生改變,不能再被modify。後面的ReLU(inplace=True)恰恰就modify了x。所以反向傳播時報錯。

改變執行的先後順序呢?也不行。如果先ReLU(inplace=True),那麼x也被modify了,再做Residual Connection時輸入就不是block輸入的那個x了。

解決的辦法,改爲ReLU(inplace=False),或者Residual Connection的輸入改爲x.clone(),總之不能省內存……正確的代碼已經push到github上了,地址詳見文章開頭。

爲此,我寫了一個簡化的模型:

  1. class Wrong就是採坑的錯誤實現;
  2. class RightOne就是改爲ReLU(inplace=False);
  3. class RightTwo就是Residual Connection的輸入改爲x.clone();

一杯茶,一包煙,一個bug改一天……

import torch
import torch.nn as nn


class Wrong(nn.Module):
    def __init__(self):
        super(Wrong, self).__init__()
        self.convs = nn.Sequential(nn.ReLU(inplace=True),
                                   nn.Conv2d(3, 3, 3, padding=1))
        self.residual = nn.Conv2d(3, 3, 3, padding=1)
        pass

    def forward(self, x):
        r = self.residual(x)  # 卷積之後,x就不能modify了
        h = self.convs(x)  # relu就modify了x,反向傳播時候會報錯
        h = h + r
        return h

    pass


class RightOne(nn.Module):
    def __init__(self):
        super(RightOne, self).__init__()
        self.convs = nn.Sequential(nn.ReLU(inplace=False),  # 改法1,別省內存了
                                   nn.Conv2d(3, 3, 3, padding=1))
        self.residual = nn.Conv2d(3, 3, 3, padding=1)
        pass

    def forward(self, x):
        r = self.residual(x)
        h = self.convs(x)
        h = h + r
        return h

    pass


class RightTwo(nn.Module):
    def __init__(self):
        super(RightTwo, self).__init__()
        self.convs = nn.Sequential(nn.ReLU(inplace=True),
                                   nn.Conv2d(3, 3, 3, padding=1))
        self.residual = nn.Conv2d(3, 3, 3, padding=1)
        pass

    def forward(self, x):
        r = self.residual(x.clone())  # 改法2,clone還是消耗內存的
        h = self.convs(x)
        h = h + r
        return h

    pass


if __name__ == '__main__':
    in_data = torch.randint(-2, 2, (1, 3, 2, 2), dtype=torch.float)
    in_label = torch.randint(0, 3, (1, 2, 2))

    print(in_data.shape)

    func = nn.CrossEntropyLoss()
    t = RightTwo()

    in_data = in_data.cuda()
    in_label = in_label.cuda()
    t.cuda()
    out_data = t(in_data)
    print(out_data.shape)

    loss = func(out_data, in_label)
    loss.backward()

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章