目錄
三、記錄pytorch採坑relu激活函數inplace=True
Xception筆記,記錄一些自己認爲重要的要點,以免日後遺忘。
復現Xception論文、DeepLabV+改進的Xception,代碼地址https://github.com/Ascetics/Pytorch-Xception
一、爲啥是Xception
Xception脫胎於Inception,Inception的思想是將卷積分成cross-channel conv和spatial conv,更準確的說是先用1x1卷積得到幾個不同channel(小於輸入channel)的結果,再在這些結果上分別用3x3、5x5 conv,也就是論文Figure 1描述的那樣。Inception的這種算法背後,本質上是將cross-channel conv和spatial conv解耦。
考慮將Inception簡化:去掉平均池化層,只用3x3 conv(2個3x3 conv相當於1個5x5 conv)。就測到了論文Figure 2描述的這種結構。
在Figure 2的基礎上,用1個channel很大的1x1 conv 將輸入映射到一個channel很大的輸出上。再將這個輸出“切成幾段”,“切成幾段”分別做3x3 spatial conv,就得到了論文中Figure 3的結構。作者在此提出一個問題,這樣將cross-channel conv和spatial conv完全解耦分開合理嗎?完全解耦分開,可以這樣做嗎?
基於Figure 3提出的假設,做一個極端的Inception模型。還是先用1個channel很大的1x1 conv 將輸入映射到一個channel很大的輸出上,然後“切成幾段”變成“切片”,每個channel切一片。對每個channel做3x3卷積。這樣極端的設計就接近於深度可分離卷積depthwise separable convolution。
爲什麼是“接近”,而不是“就是”呢?因爲和depthwise separable convolution的操作順序、操作內容不一樣。
- 順序上,depthwise separable convolution,先用3x3 conv進行spatial conv,後用1x1 conv進行cross-channel conv;極端版本Inception先用1x1 conv再用3x3 conv;
- 內容上,depthwise separable convolution,spatial conv和cross-channel conv之間沒有非線性(ReLU激活函數);極端版本Inception,卷積之間有非線性(ReLU激活函數);
作者認爲第一個區別是不重要的,特別是因爲這些操作要在堆疊(深度學習)的環境中使用。第二個區別重要,作者研究了一下,結論見論文Figure 10。本文後面會解釋。
要看懂Xception,需要了解VGG、Inception、Depthwise Separable Convlution和ResNet,都會用到。
二、Xception結構
2.1 Xception結構基本描述
卷積神經網絡特徵提取中的卷積都可以完全解耦,變成深度可分離卷積(Xception也就是Extreme Inception的意思)。接收了這一設定,Xception結構被解釋爲論文Figure 5的樣子。
Xception的特徵提取基礎由36個conv layer構成。這36個conv layer被組織成14個module,除了第一個和最後一個module,其餘的module都帶有residual connection(殘差,參看何凱明大神的ResNet)。簡言之,Xception結構就是連續使用depthwise separable convolution layer和residual connection。
2.2 實現細節
如Figure 5 描述所述。
輸入先經過Entry flow,不重複;再經過Middle flow,Middle flow重複8次;最後經過Exit flow,不重複。
所有的Conv 和 Separable Conv後面都加BN層,但是Figure 5沒有畫出來。
所有的Separable Conv都用depth=1,也就是每個depth-wise都是“切片”的。
注意, depthwise separable convolution在spatial conv和cross-channel conv之間不要加ReLU激活函數,任何激活函數都不要加。論文Figure 10展示了,這裏不加激活函數效果最好,加ReLU、ELU都不好。
還有一些是論文中沒有明說的細節。
Residual Connection在1x1卷積後面也加上BN。Residual Connection加上以後,不要着急做激活函數,仔細看圖,激活函數ReLU是屬於下一個Block的。這就導致了代碼實現上採坑,下一節詳細記錄一下。也算長個記性。
2.3 DeepLabV3+改進
這一部分在下一篇博客,學DeepLabV3+中再記錄。
三、記錄pytorch採坑relu激活函數inplace=True
上面2.2寫了一個細節,如果嚴格按照論文的示意圖來實現Xception,那麼每個block第一個操作不是SeparableConv,而是ReLU(紅色框)。如果僅僅是第一個操作是ReLU也沒有關係,但是旁邊還有個Residual Connection(藍色框)。自古紅藍出CP,於是坑來了,在反向傳播的時候,報了個錯,明顯是因爲inplace導致了某個對象被modify了,反向傳播求梯度報錯。(此時我所有代碼的ReLU用的都是inplace=True,省內存嘛)
以Entry Flow的Block爲例,先來欣賞一下錯誤的代碼。
class _PoolEntryBlock(nn.Module):
def __init__(self, in_channels, out_channels, relu1=True):
"""
Entry Flow的3個下采樣module
按論文所說,每個Conv和Separable Conv都需要跟BN
論文Figure 5中,第1個Separable Conv前面沒有ReLU,需要判斷一下,
論文Figure 5中,每個module的Separable Conv的out_channels一樣,MaxPool做下采樣
:param in_channels: 輸入channels
:param out_channels: 輸出channels
:param relu1: 判斷有沒有第一個ReLU,默認是有的
"""
super(_PoolEntryBlock, self).__init__()
self.project = ResidualConnection(in_channels, out_channels, stride=2)
self.relu1 = None
if relu1:
self.relu1 = nn.ReLU(inplace=True)
self.sepconv1 = SeparableConv2d(in_channels, out_channels,
kernel_size=3, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu2 = nn.ReLU(inplace=True)
self.sepconv2 = SeparableConv2d(out_channels, out_channels,
kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
pass
def forward(self, x):
identity = self.project(x) # residual connection 準備
if self.relu1: # 第1個Separable Conv前面沒有ReLU,需要判斷一下
x = self.relu1(x)
x = self.sepconv1(x) # 第2個Separable Conv
x = self.bn1(x)
x = self.relu2(x)
x = self.sepconv2(x) # 第2個Separable Conv
x = self.bn2(x)
x = self.maxpool(x) # 下采樣2倍
x = x + identity # residual connection 相加
return x
pass
採坑時的做法,先算出Residual Connection,再做relu、SeparableConv。Residual Connection時,已經進行過一次卷積操作,此時要求輸入x本身不能發生改變,不能再被modify。後面的ReLU(inplace=True)恰恰就modify了x。所以反向傳播時報錯。
改變執行的先後順序呢?也不行。如果先ReLU(inplace=True),那麼x也被modify了,再做Residual Connection時輸入就不是block輸入的那個x了。
解決的辦法,改爲ReLU(inplace=False),或者Residual Connection的輸入改爲x.clone(),總之不能省內存……正確的代碼已經push到github上了,地址詳見文章開頭。
爲此,我寫了一個簡化的模型:
- class Wrong就是採坑的錯誤實現;
- class RightOne就是改爲ReLU(inplace=False);
- class RightTwo就是Residual Connection的輸入改爲x.clone();
一杯茶,一包煙,一個bug改一天……
import torch
import torch.nn as nn
class Wrong(nn.Module):
def __init__(self):
super(Wrong, self).__init__()
self.convs = nn.Sequential(nn.ReLU(inplace=True),
nn.Conv2d(3, 3, 3, padding=1))
self.residual = nn.Conv2d(3, 3, 3, padding=1)
pass
def forward(self, x):
r = self.residual(x) # 卷積之後,x就不能modify了
h = self.convs(x) # relu就modify了x,反向傳播時候會報錯
h = h + r
return h
pass
class RightOne(nn.Module):
def __init__(self):
super(RightOne, self).__init__()
self.convs = nn.Sequential(nn.ReLU(inplace=False), # 改法1,別省內存了
nn.Conv2d(3, 3, 3, padding=1))
self.residual = nn.Conv2d(3, 3, 3, padding=1)
pass
def forward(self, x):
r = self.residual(x)
h = self.convs(x)
h = h + r
return h
pass
class RightTwo(nn.Module):
def __init__(self):
super(RightTwo, self).__init__()
self.convs = nn.Sequential(nn.ReLU(inplace=True),
nn.Conv2d(3, 3, 3, padding=1))
self.residual = nn.Conv2d(3, 3, 3, padding=1)
pass
def forward(self, x):
r = self.residual(x.clone()) # 改法2,clone還是消耗內存的
h = self.convs(x)
h = h + r
return h
pass
if __name__ == '__main__':
in_data = torch.randint(-2, 2, (1, 3, 2, 2), dtype=torch.float)
in_label = torch.randint(0, 3, (1, 2, 2))
print(in_data.shape)
func = nn.CrossEntropyLoss()
t = RightTwo()
in_data = in_data.cuda()
in_label = in_label.cuda()
t.cuda()
out_data = t(in_data)
print(out_data.shape)
loss = func(out_data, in_label)
loss.backward()