DenseNet:比ResNet更優的CNN模型

前言

在計算機視覺領域,卷積神經網絡(CNN)已經成爲最主流的方法,比如最近的GoogLenet,VGG-19,Incepetion等模型。CNN史上的一個里程碑事件是ResNet模型的出現,ResNet可以訓練出更深的CNN模型,從而實現更高的準確度。ResNet模型的核心是通過建立前面層與後面層之間的“短路連接”(shortcuts,skip connection),這有助於訓練過程中梯度的反向傳播,從而能訓練出更深的CNN網絡。今天我們要介紹的是DenseNet模型,它的基本思路與ResNet一致,但是它建立的是前面所有層與後面層的密集連接(dense connection),它的名稱也是由此而來。DenseNet的另一大特色是通過特徵在channel上的連接來實現特徵重用(feature reuse)。這些特點讓DenseNet在參數和計算成本更少的情形下實現比ResNet更優的性能,DenseNet也因此斬獲CVPR 2017的最佳論文獎。本篇文章首先介紹DenseNet的原理以及網路架構,然後講解DenseNet在Pytorch上的實現。

 

設計理念

相比ResNet,DenseNet提出了一個更激進的密集連接機制:即互相連接所有的層,具體來說就是每個層都會接受其前面所有層作爲其額外的輸入。圖1爲ResNet網絡的連接機制,作爲對比,圖2爲DenseNet的密集連接機制。可以看到,ResNet是每個層與前面的某層(一般是2~3層)短路連接在一起,連接方式是通過元素級相加。而在DenseNet中,每個層都會與前面所有層在channel維度上連接(concat)在一起(這裏各個層的特徵圖大小是相同的,後面會有說明),並作爲下一層的輸入。對於一個 L 層的網絡,DenseNet共包含 \frac{L(L+1)}{2} 個連接,相比ResNet,這是一種密集連接。而且DenseNet是直接concat來自不同層的特徵圖,這可以實現特徵重用,提升效率,這一特點是DenseNet與ResNet最主要的區別。

圖1 ResNet網絡的短路連接機制(其中+代表的是元素級相加操作)

圖2 DenseNet網絡的密集連接機制(其中c代表的是channel級連接操作)

如果用公式表示的話,傳統的網絡在 l 層的輸出爲:
\\x_l = H_l(x_{l-1})

而對於ResNet,增加了來自上一層輸入的identity函數:
\\x_l = H_l(x_{l-1}) + x_{l-1}

在DenseNet中,會連接前面所有層作爲輸入:
\\x_l = H_l([x_0, x_1, ..., x_{l-1}])

其中,上面的 H_l(\cdot) 代表是非線性轉化函數(non-liear transformation),它是一個組合操作,其可能包括一系列的BN(Batch Normalization),ReLU,Pooling及Conv操作。注意這裏 l 層與 l-1 層之間可能實際上包含多個卷積層。

DenseNet的前向過程如圖3所示,可以更直觀地理解其密集連接方式,比如 h_3 的輸入不僅包括來自 h_2 的 x_2 ,還包括前面兩層的 x_1 和 x_2 ,它們是在channel維度上連接在一起的。

圖3 DenseNet的前向過程

CNN網絡一般要經過Pooling或者stride>1的Conv來降低特徵圖的大小,而DenseNet的密集連接方式需要特徵圖大小保持一致。爲了解決這個問題,DenseNet網絡中使用DenseBlock+Transition的結構,其中DenseBlock是包含很多層的模塊,每個層的特徵圖大小相同,層與層之間採用密集連接方式。而Transition模塊是連接兩個相鄰的DenseBlock,並且通過Pooling使特徵圖大小降低。圖4給出了DenseNet的網路結構,它共包含4個DenseBlock,各個DenseBlock之間通過Transition連接在一起。

圖4 使用DenseBlock+Transition的DenseNet網絡

網絡結構

如前所示,DenseNet的網絡結構主要由DenseBlock和Transition組成,如圖5所示。下面具體介紹網絡的具體實現細節。

圖6 DenseNet的網絡結構

在DenseBlock中,各個層的特徵圖大小一致,可以在channel維度上連接。DenseBlock中的非線性組合函數 H(\cdot) 採用的是BN+ReLU+3x3 Conv的結構,如圖6所示。另外值得注意的一點是,與ResNet不同,所有DenseBlock中各個層卷積之後均輸出 k 個特徵圖,即得到的特徵圖的channel數爲 k ,或者說採用 k 個卷積核。 k 在DenseNet稱爲growth rate,這是一個超參數。一般情況下使用較小的 k (比如12),就可以得到較佳的性能。假定輸入層的特徵圖的channel數爲 k_0 ,那麼 l 層輸入的channel數爲 k_0+k(l-1) ,因此隨着層數增加,儘管 k 設定得較小,DenseBlock的輸入會非常多,不過這是由於特徵重用所造成的,每個層僅有 k 個特徵是自己獨有的。

圖6 DenseBlock中的非線性轉換結構

由於後面層的輸入會非常大,DenseBlock內部可以採用bottleneck層來減少計算量,主要是原有的結構中增加1x1 Conv,如圖7所示,即BN+ReLU+1x1 Conv+BN+ReLU+3x3 Conv,稱爲DenseNet-B結構。其中1x1 Conv得到 4k 個特徵圖它起到的作用是降低特徵數量,從而提升計算效率。

圖7 使用bottleneck層的DenseBlock結構

對於Transition層,它主要是連接兩個相鄰的DenseBlock,並且降低特徵圖大小。Transition層包括一個1x1的卷積和2x2的AvgPooling,結構爲BN+ReLU+1x1 Conv+2x2 AvgPooling。另外,Transition層可以起到壓縮模型的作用。假定Transition的上接DenseBlock得到的特徵圖channels數爲 m ,Transition層可以產生 \lfloor\theta m\rfloor 個特徵(通過卷積層),其中 \theta \in (0,1] 是壓縮係數(compression rate)。當 \theta=1 時,特徵個數經過Transition層沒有變化,即無壓縮,而當壓縮係數小於1時,這種結構稱爲DenseNet-C,文中使用 \theta=0.5 。對於使用bottleneck層的DenseBlock結構和壓縮係數小於1的Transition組合結構稱爲DenseNet-BC。

DenseNet共在三個圖像分類數據集(CIFAR,SVHN和ImageNet)上進行測試。對於前兩個數據集,其輸入圖片大小爲 32\times 32 ,所使用的DenseNet在進入第一個DenseBlock之前,首先進行進行一次3x3卷積(stride=1),卷積核數爲16(對於DenseNet-BC爲 2k )。DenseNet共包含三個DenseBlock,各個模塊的特徵圖大小分別爲 32\times 32 , 16\times 16 和 8\times 8 ,每個DenseBlock裏面的層數相同。最後的DenseBlock之後是一個global AvgPooling層,然後送入一個softmax分類器。注意,在DenseNet中,所有的3x3卷積均採用padding=1的方式以保證特徵圖大小維持不變。對於基本的DenseNet,使用如下三種網絡配置: \{L=40, k=12\} , \{L=100, k=12\} , \{L=40, k=24\} 。而對於DenseNet-BC結構,使用如下三種網絡配置: \{L=100, k=12\} , \{L=250, k=24\} , \{L=190, k=40\} 。這裏的 L 指的是網絡總層數(網絡深度),一般情況下,我們只把帶有訓練參數的層算入其中,而像Pooling這樣的無參數層不納入統計中,此外BN層儘管包含參數但是也不單獨統計,而是可以計入它所附屬的卷積層。對於普通的 {L=40, k=12} 網絡,除去第一個卷積層、2個Transition中卷積層以及最後的Linear層,共剩餘36層,均分到三個DenseBlock可知每個DenseBlock包含12層。其它的網絡配置同樣可以算出各個DenseBlock所含層數。

對於ImageNet數據集,圖片輸入大小爲 224\times 224 ,網絡結構採用包含4個DenseBlock的DenseNet-BC,其首先是一個stride=2的7x7卷積層(卷積核數爲 2k ),然後是一個stride=2的3x3 MaxPooling層,後面才進入DenseBlock。ImageNet數據集所採用的網絡配置如表1所示:

表1 ImageNet數據集上所採用的DenseNet結構

實驗結果及討論

這裏給出DenseNet在CIFAR-100和ImageNet數據集上與ResNet的對比結果,如圖8和9所示。從圖8中可以看到,只有0.8M的DenseNet-100性能已經超越ResNet-1001,並且後者參數大小爲10.2M。而從圖9中可以看出,同等參數大小時,DenseNet也優於ResNet網絡。其它實驗結果見原論文。

圖8 在CIFAR-100數據集上ResNet vs DenseNet

圖9 在ImageNet數據集上ResNet vs DenseNet

綜合來看,DenseNet的優勢主要體現在以下幾個方面:

  • 由於密集連接方式,DenseNet提升了梯度的反向傳播,使得網絡更容易訓練。由於每層可以直達最後的誤差信號,實現了隱式的“deep supervision”
  • 參數更小且計算更高效,這有點違反直覺,由於DenseNet是通過concat特徵來實現短路連接,實現了特徵重用,並且採用較小的growth rate,每個層所獨有的特徵圖是比較小的;
  • 由於特徵複用,最後的分類器使用了低級特徵。

要注意的一點是,如果實現方式不當的話,DenseNet可能耗費很多GPU顯存,一種高效的實現如圖10所示,更多細節可以見這篇論文Memory-Efficient Implementation of DenseNets。不過我們下面使用Pytorch框架可以自動實現這種優化。

圖10 DenseNet的更高效實現方式

使用Pytorch實現DenseNet

這裏我們採用Pytorch框架來實現DenseNet,目前它已經支持Windows系統。對於DenseNet,Pytorch在torchvision.models模塊裏給出了官方實現,這個DenseNet版本是用於ImageNet數據集的DenseNet-BC模型,下面簡單介紹實現過程。

首先實現DenseBlock中的內部結構,這裏是BN+ReLU+1x1 Conv+BN+ReLU+3x3 Conv結構,最後也加入dropout層以用於訓練過程。

class _DenseLayer(nn.Sequential):
    """Basic unit of DenseBlock (using bottleneck layer) """
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super(_DenseLayer, self).__init__()
        self.add_module("norm1", nn.BatchNorm2d(num_input_features))
        self.add_module("relu1", nn.ReLU(inplace=True))
        self.add_module("conv1", nn.Conv2d(num_input_features, bn_size*growth_rate,
                                           kernel_size=1, stride=1, bias=False))
        self.add_module("norm2", nn.BatchNorm2d(bn_size*growth_rate))
        self.add_module("relu2", nn.ReLU(inplace=True))
        self.add_module("conv2", nn.Conv2d(bn_size*growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1, bias=False))
        self.drop_rate = drop_rate

    def forward(self, x):
        new_features = super(_DenseLayer, self).forward(x)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return torch.cat([x, new_features], 1)

據此,實現DenseBlock模塊,內部是密集連接方式(輸入特徵數線性增長):

class _DenseBlock(nn.Sequential):
    """DenseBlock"""
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(num_input_features+i*growth_rate, growth_rate, bn_size,
                                drop_rate)
            self.add_module("denselayer%d" % (i+1,), layer)

此外,我們實現Transition層,它主要是一個卷積層和一個池化層:

class _Transition(nn.Sequential):
    """Transition layer between two adjacent DenseBlock"""
    def __init__(self, num_input_feature, num_output_features):
        super(_Transition, self).__init__()
        self.add_module("norm", nn.BatchNorm2d(num_input_feature))
        self.add_module("relu", nn.ReLU(inplace=True))
        self.add_module("conv", nn.Conv2d(num_input_feature, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module("pool", nn.AvgPool2d(2, stride=2))

最後我們實現DenseNet網絡:

class DenseNet(nn.Module):
    "DenseNet-BC model"
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64,
                 bn_size=4, compression_rate=0.5, drop_rate=0, num_classes=1000):
        """
        :param growth_rate: (int) number of filters used in DenseLayer, `k` in the paper
        :param block_config: (list of 4 ints) number of layers in each DenseBlock
        :param num_init_features: (int) number of filters in the first Conv2d
        :param bn_size: (int) the factor using in the bottleneck layer
        :param compression_rate: (float) the compression rate used in Transition Layer
        :param drop_rate: (float) the drop rate after each DenseLayer
        :param num_classes: (int) number of classes for classification
        """
        super(DenseNet, self).__init__()
        # first Conv2d
        self.features = nn.Sequential(OrderedDict([
            ("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ("norm0", nn.BatchNorm2d(num_init_features)),
            ("relu0", nn.ReLU(inplace=True)),
            ("pool0", nn.MaxPool2d(3, stride=2, padding=1))
        ]))

        # DenseBlock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers, num_features, bn_size, growth_rate, drop_rate)
            self.features.add_module("denseblock%d" % (i + 1), block)
            num_features += num_layers*growth_rate
            if i != len(block_config) - 1:
                transition = _Transition(num_features, int(num_features*compression_rate))
                self.features.add_module("transition%d" % (i + 1), transition)
                num_features = int(num_features * compression_rate)

        # final bn+ReLU
        self.features.add_module("norm5", nn.BatchNorm2d(num_features))
        self.features.add_module("relu5", nn.ReLU(inplace=True))

        # classification layer
        self.classifier = nn.Linear(num_features, num_classes)

        # params initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.features(x)
        out = F.avg_pool2d(features, 7, stride=1).view(features.size(0), -1)
        out = self.classifier(out)
        return out

選擇不同網絡參數,就可以實現不同深度的DenseNet,這裏實現DenseNet-121網絡,而且Pytorch提供了預訓練好的網絡參數:

def densenet121(pretrained=False, **kwargs):
    """DenseNet121"""
    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
                     **kwargs)

    if pretrained:
        # '.'s are no longer allowed in module names, but pervious _DenseLayer
        # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
        # They are also in the checkpoints in model_urls. This pattern is used
        # to find such keys.
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        state_dict = model_zoo.load_url(model_urls['densenet121'])
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        model.load_state_dict(state_dict)
    return model

下面,我們使用預訓練好的網絡對圖片進行測試,這裏給出top-5預測值:

densenet = densenet121(pretrained=True)
densenet.eval()

img = Image.open("./images/cat.jpg")

trans_ops = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

images = trans_ops(img).view(-1, 3, 224, 224)
outputs = densenet(images)

_, predictions = outputs.topk(5, dim=1)

labels = list(map(lambda s: s.strip(), open("./data/imagenet/synset_words.txt").readlines()))
for idx in predictions.numpy()[0]:
    print("Predicted labels:", labels[idx])

給出的預測結果爲:

Predicted labels: n02123159 tiger cat
Predicted labels: n02123045 tabby, tabby cat
Predicted labels: n02127052 lynx, catamount
Predicted labels: n02124075 Egyptian cat
Predicted labels: n02119789 kit fox, Vulpes macrotis

注:完整代碼見xiaohu2015/DeepLearning_tutorials

小結

這篇文章詳細介紹了DenseNet的設計理念以及網絡結構,並給出瞭如何使用Pytorch來實現。值得注意的是,DenseNet在ResNet基礎上前進了一步,相比ResNet具有一定的優勢,但是其卻並沒有像ResNet那麼出名(吃顯存問題?深度不能太大?)。期待未來有更好的網絡模型出現吧!

參考文獻

  1. DenseNet-CVPR-Slides.
  2. Densely Connected Convolutional Networks.

 

轉自:https://zhuanlan.zhihu.com/p/37189203

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章