tensorflow2製作Resnet殘差網絡

使用tensorflow2復現 Resnet v1&v2

關於resnet的原理、解析參考論文以及網絡博客, 這裏就不再複述了, 我主要是看其他框架版本的resnet進行理解, 畢竟文字描述不清楚, 細節都在代碼中體現

導入相應的包

這裏使用的tensorflow2

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, Dense
from tensorflow.keras.layers import Input, Flatten, AveragePooling2D

構建網絡

先看看兩個版本的區別

如下圖片對resnet v1還是v2的區別描述的很清晰, 根據此圖, 創建網絡模塊, 當然把兩個版本都寫出來.
resnet結構圖片
可以看出,V1 V2 的區別就在於 (V1)Conv->BN->Activation 還是 (V2)BN->Activation->Conv
----------------------------------------------------------分割線----------------------------------------------------------

實現基礎模塊

如下圖, 我個人認爲,圖中的 weight layer 的具體內容就是上述的 (V1)Conv->BN->Activation 或者 (V2)BN->Activation->Conv (注意!一個weight layer中不一定都存在這三者), 因此讓我們實現這一小塊的內容。Resnet的Shortcut

def weight_layer(inputs,
                 filters=16,
                 kernel_size=3,
                 strides=1,
                 activation='relu', 
                 batch_normalization=True, 
                 conv_first=True):
    '''
    這裏僅僅只是實現殘差中的 (V1)Conv->BN->Activation 或者 (V2)BN->Activation->Conv
    args:
        inputs(tensor):     最初輸入的圖片或者上一層輸入的圖片
        filters(int):       卷積核的數量
        kernel_size(int):   卷積核的大小
        strides(int):       卷積核移動的距離
        activation(str):    激活函數
        batch_normalization(bool): 是否使用batch_normalization
        conv_first(bool):   Resnet使用版本, resnetv1(True), resnetv2(False)
    '''
    conv = Conv2D(filters, 
                  kernel_size=kernel_size,
                  strides=strides, 
                  padding='same', 
                  kernel_initializer='he_normal',   
                  kernel_regularizer=keras.regularizers.l2(1e-4))
    
    x = inputs
    if conv_first:
        # Resnet v1
        x = conv(x)
        if batch_normalization:     # 可以隨意選擇是否添加BN層
            x = BatchNormalization()(x)
        if activation is not None:  # 可以隨意選擇是否添加激活函數層
            x = Activation(activation)(x)
    else:
        # Resnet v2
        if batch_normalization:
            x = BatchNormalization()(x)
        if activation is not None:
            x = Activation(activation)(x)
        x = conv(x)
    return x

接下來實現 Resnet V1 的網絡結構,每個小塊的結構都與第一張圖裏的一致,爲了便於看代碼我把網絡結構圖(depth=8)貼在下面:
Resnet V1

def resnet_v1(input_shape, depth, num_classes=10):
    '''
    實現Resnet V1網絡

    args:
        input_shape(tube): 輸入圖像的shape,例如(128, 128, 3)
        depth(int):        網絡的深度
        num_classes(int):  分類器輸出結果的種類 
    
    return:
        model:            Resnet V1網絡模型
    '''
    # depth必須是6n+2 
    # 2 是指剛開始的weight layers塊以及最後的分類器
    # 6n 是因爲: stack一共三個,每個stack中有兩個res_blocks (具體看下面的循環)
    if (depth - 2) % 6 != 0:
        raise ValueError('網絡深度必須是 6n+2!')
    num_filters = 16
    num_res_blocks = int((depth - 2) / 6)

    inputs = Input(shape=input_shape)
    x = weight_layer(inputs)

    for stack in range(3):
        for res_block in range(num_res_blocks):
            strides = 1
            if stack > 0 and res_block == 0:
                strides = 2

            y = weight_layer(x, filters=num_filters, strides=strides)
            # 這裏只添加了一個卷積層以及一個BN層
            y = weight_layer(y, filters=num_filters, activation=None)
            # 緯度不對應時(因strides=2引起的),使用size=1的卷積核進行調整,讓x的緯度變大以便與y相加(通過num_filters讓緯度相同)
            # 這裏的1*1的卷積更多的起得是整合緯度的作用
            if stack > 0 and res_block == 0:
                x = weight_layer(x, 
                filters=num_filters, 
                kernel_size=1, 
                strides=strides, 
                activation=None, 
                batch_normalization=False)

            x = keras.layers.add([x, y])
            x = Activation('relu')(x)
        num_filters *= 2
    
    # 添加分類器
    x = AveragePooling2D(pool_size=8)(x)
    y = Flatten()(x)
    outputs = Dense(num_classes, activation='softmax', kernel_initializer='he_normal')(y)

    model = keras.Model(inputs=inputs, outputs=outputs)
    return model

未完待續

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章