Understanding the Deep Residual Shrinkage Network in 10 Minutes

The deep residual network (ResNet) won the Best Paper Award at the 2016 IEEE Conference on Computer Vision and Pattern Recognition, and its Google Scholar citation count has reached 38,295.

The deep residual shrinkage network is an improved variant of the deep residual network. It is essentially an integration of the deep residual network, the attention mechanism, and the soft thresholding function.

To some extent, the working principle of the deep residual shrinkage network can be understood as follows: the attention mechanism notices the unimportant features and the soft thresholding function sets them to zero; or, put the other way around, the attention mechanism notices the important features and keeps them. Either way, the network's ability to extract useful features from noisy signals is strengthened.

1. Why propose the deep residual shrinkage network?

First, when classifying samples, the samples inevitably contain some noise, such as Gaussian noise, pink noise, or Laplacian noise. More broadly, the samples are likely to contain information that is irrelevant to the current classification task, and this information can also be regarded as noise. Such noise may harm classification performance. (Soft thresholding is a key step in many signal denoising algorithms.)

For example, when people chat at the roadside, the recorded speech may be mixed with car horns, tire noise, and so on. When speech recognition is performed on such signals, the recognition accuracy is inevitably degraded by the horns and tire noise. From a deep-learning perspective, the features corresponding to the horns and tire noise should be removed inside the deep neural network so that they do not affect the recognition result.

Second, even within the same data set, the amount of noise usually differs from sample to sample. (This is analogous to the attention mechanism: taking an image data set as an example, the location of the target object may differ from image to image, and the attention mechanism can locate the target object in each individual image.)

For example, when training a cat-vs-dog classifier, among five images labeled "dog", the first may also contain a mouse, the second a goose, the third a chicken, the fourth a donkey, and the fifth a duck. Training is then inevitably disturbed by these irrelevant objects, which lowers the classification accuracy. If we could notice the irrelevant mouse, goose, chicken, donkey, and duck and delete their corresponding features, the accuracy of the cat-vs-dog classifier might improve.

2. Soft thresholding is the core step of many signal denoising algorithms

Soft thresholding is the core step of many signal denoising algorithms: features whose absolute values are below a certain threshold are set to zero, and features whose absolute values are above the threshold are shrunk toward zero. It can be implemented by the following formula, where x is the input feature, y is the output feature, and τ is the threshold:

y = x - τ    (x > τ)
y = 0        (-τ ≤ x ≤ τ)
y = x + τ    (x < -τ)

The derivative of the soft-thresholding output with respect to the input is

dy/dx = 1    (x > τ or x < -τ)
dy/dx = 0    (-τ ≤ x ≤ τ)

As the derivative shows, the gradient of soft thresholding is either 1 or 0, which is the same property as the ReLU activation function. Therefore, soft thresholding can also reduce the risk of vanishing and exploding gradients in deep learning algorithms.
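For concreteness, here is a minimal NumPy sketch of soft thresholding (the fixed threshold 0.5 is only illustrative; in the deep residual shrinkage network the threshold is learned, as described in Section 4):

import numpy as np

def soft_threshold(x, tau):
    # Set features with |x| <= tau to zero; shrink the rest toward zero by tau
    return np.sign(x) * np.maximum(np.abs(x) - tau, 0.0)

x = np.array([-1.5, -0.2, 0.0, 0.3, 2.0])
print(soft_threshold(x, tau=0.5))   # magnitudes below 0.5 become 0, the rest shrink by 0.5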

In the soft thresholding function, the threshold must satisfy two conditions: first, it must be positive; second, it must not exceed the maximum absolute value of the input signal, otherwise the output would be all zeros.

Ideally, the threshold should also satisfy a third condition: each sample should have its own threshold, determined by its own noise content.

This is because the noise content often differs across samples. For example, it frequently happens that, within the same data set, sample A contains little noise while sample B contains a lot. When soft thresholding is applied in a denoising algorithm, sample A should then use a smaller threshold and sample B a larger one. In a deep neural network, these features and thresholds no longer have a clear physical meaning, but the basic principle still holds: each sample should have its own threshold, determined by its own noise content.

3. The attention mechanism

The attention mechanism is relatively easy to understand in the context of computer vision. An animal's visual system can quickly scan the whole field of view, find the target object, and then focus attention on that object to extract more detail while suppressing irrelevant information. For details, please refer to the literature on attention mechanisms.

The Squeeze-and-Excitation Network (SENet) is a relatively recent deep learning method based on the attention mechanism. In different samples, different feature channels usually contribute differently to the classification task. SENet uses a small sub-network to obtain a set of weights and then multiplies these weights with the features of the corresponding channels to rescale each channel. This process can be regarded as applying attention of different strengths to the individual feature channels.

In this scheme, every sample gets its own set of weights; in other words, any two samples have different weights. In SENet, the weights are obtained along the path "global pooling → fully connected layer → ReLU → fully connected layer → Sigmoid".
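As a rough Keras sketch of that path (the function name se_block and the reduction ratio of 4 are illustrative choices, not taken from the article or the SENet paper):

from keras.layers import Dense, GlobalAveragePooling2D, Reshape, multiply

def se_block(feature_map, out_channels, reduction=4):
    # Squeeze: one value per channel via global average pooling
    w = GlobalAveragePooling2D()(feature_map)
    # Excitation: FC -> ReLU -> FC -> Sigmoid yields a per-channel weight in (0, 1)
    w = Dense(out_channels // reduction, activation='relu')(w)
    w = Dense(out_channels, activation='sigmoid')(w)
    w = Reshape((1, 1, out_channels))(w)
    # Rescale each channel of the input feature map by its learned weight
    return multiply([feature_map, w])

Because the weights are computed from each sample's own feature map, every sample receives its own set of channel weights.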

4. Soft thresholding under a deep attention mechanism

The deep residual shrinkage network borrows the SENet sub-network structure described above to implement soft thresholding under a deep attention mechanism: a dedicated sub-network learns a set of thresholds, which are then used to soft-threshold each feature channel.

In this sub-network, the absolute values of all features in the input feature map are computed first. After global average pooling and averaging, a single feature is obtained, denoted A. Along another path, the feature map after global average pooling is fed into a small fully connected network. This fully connected network ends with a Sigmoid layer that normalizes its output to the interval (0, 1), yielding a coefficient denoted α. The final threshold is α × A; that is, the threshold is a number between 0 and 1 multiplied by the average of the absolute values of the feature map. This ensures that the threshold is positive and not too large.
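Here is a minimal Keras sketch of this threshold sub-network combined with soft thresholding (the helper name shrinkage and the layer widths are my own illustrative choices; like the programs in Section 6, it learns one threshold per channel; the complete, trainable version appears in the Keras program below):

from keras import backend as K
from keras.layers import Dense, GlobalAveragePooling2D, Lambda, Reshape, multiply

def shrinkage(feature_map, out_channels):
    # Path 1: per-channel average of absolute values, the "A" in the text
    abs_map = Lambda(lambda x: K.abs(x))(feature_map)
    abs_mean = GlobalAveragePooling2D()(abs_map)            # shape (batch, channels)
    # Path 2: a small FC network ending in Sigmoid gives a coefficient alpha in (0, 1)
    alpha = Dense(out_channels, activation='relu')(abs_mean)
    alpha = Dense(out_channels, activation='sigmoid')(alpha)
    # Threshold = alpha * A: positive and never larger than the mean absolute value
    thres = multiply([abs_mean, alpha])
    thres = Reshape((1, 1, out_channels))(thres)
    # Soft thresholding: sign(x) * max(|x| - threshold, 0)
    return Lambda(lambda t: K.sign(t[0]) * K.maximum(K.abs(t[0]) - t[1], 0.))([feature_map, thres])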

Moreover, different samples end up with different thresholds. To some extent, this can therefore be understood as a special attention mechanism: it notices the features irrelevant to the current task and sets them to zero through soft thresholding; or, equivalently, it notices the features relevant to the current task and keeps them.

Finally, stacking a certain number of these basic modules together with convolutional layers, batch normalization, activation functions, global average pooling, and a fully connected output layer yields the complete deep residual shrinkage network.

5. The deep residual shrinkage network may be more broadly applicable

The deep residual shrinkage network is in fact a general feature-learning method, because in many feature-learning tasks the samples contain, to a greater or lesser extent, noise and irrelevant information that may degrade feature learning. For example:

In image classification, if an image also contains many other objects, those objects can be regarded as "noise". The deep residual shrinkage network may be able to use the attention mechanism to notice this "noise" and use soft thresholding to set the corresponding features to zero, which could improve image-classification accuracy.

In speech recognition in noisy environments, for example conversations at the roadside or on a factory floor, the deep residual shrinkage network might improve recognition accuracy, or at least suggests an approach for doing so.

6. A brief introduction to the Keras and TFLearn programs

Taking image classification as an example, the programs build a small deep residual shrinkage network; the hyper-parameters are not tuned. For higher accuracy, you can increase the depth, train for more epochs, and adjust the hyper-parameters. The Keras program is as follows:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Dec 28 23:24:05 2019
Implemented using TensorFlow 1.0.1 and Keras 2.2.1
 
M. Zhao, S. Zhong, X. Fu, et al., Deep Residual Shrinkage Networks for Fault Diagnosis, 
IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898
@author: super_9527
"""

from __future__ import print_function
import keras
import numpy as np
from keras.datasets import mnist
from keras.layers import Dense, Conv2D, BatchNormalization, Activation
from keras.layers import AveragePooling2D, Input, GlobalAveragePooling2D
from keras.optimizers import Adam
from keras.regularizers import l2
from keras import backend as K
from keras.models import Model
from keras.layers.core import Lambda
K.set_learning_phase(1)

# Input image dimensions
img_rows, img_cols = 28, 28

# The data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

# Noised data
x_train = x_train.astype('float32') / 255. + 0.5*np.random.random([x_train.shape[0], img_rows, img_cols, 1])
x_test = x_test.astype('float32') / 255. + 0.5*np.random.random([x_test.shape[0], img_rows, img_cols, 1])
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)


def abs_backend(inputs):
    return K.abs(inputs)

def expand_dim_backend(inputs):
    return K.expand_dims(K.expand_dims(inputs,1),1)

def sign_backend(inputs):
    return K.sign(inputs)

def pad_backend(inputs, in_channels, out_channels):
    pad_dim = (out_channels - in_channels)//2
    inputs = K.expand_dims(inputs,-1)
    inputs = K.spatial_3d_padding(inputs, ((0,0),(0,0),(pad_dim,pad_dim)), 'channels_last')
    return K.squeeze(inputs, -1)

# Residual Shrinkage Block
def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
                             downsample_strides=2):
    
    residual = incoming
    in_channels = incoming.get_shape().as_list()[-1]
    
    for i in range(nb_blocks):
        
        identity = residual
        
        if not downsample:
            downsample_strides = 1
        
        residual = BatchNormalization()(residual)
        residual = Activation('relu')(residual)
        residual = Conv2D(out_channels, 3, strides=(downsample_strides, downsample_strides),
                          padding='same', kernel_initializer='he_normal',
                          kernel_regularizer=l2(1e-4))(residual)
        
        residual = BatchNormalization()(residual)
        residual = Activation('relu')(residual)
        residual = Conv2D(out_channels, 3, padding='same', kernel_initializer='he_normal',
                          kernel_regularizer=l2(1e-4))(residual)
        
        # Calculate global means
        residual_abs = Lambda(abs_backend)(residual)
        abs_mean = GlobalAveragePooling2D()(residual_abs)
        
        # Calculate scaling coefficients
        scales = Dense(out_channels, activation=None, kernel_initializer='he_normal',
                       kernel_regularizer=l2(1e-4))(abs_mean)
        scales = BatchNormalization()(scales)
        scales = Activation('relu')(scales)
        scales = Dense(out_channels, activation='sigmoid', kernel_regularizer=l2(1e-4))(scales)
        scales = Lambda(expand_dim_backend)(scales)
        
        # Calculate thresholds
        thres = keras.layers.multiply([abs_mean, scales])
        
        # Soft thresholding
        sub = keras.layers.subtract([residual_abs, thres])
        zeros = keras.layers.subtract([sub, sub])
        n_sub = keras.layers.maximum([sub, zeros])
        residual = keras.layers.multiply([Lambda(sign_backend)(residual), n_sub])
        
        # Downsampling (it is important to use a pool size of (1, 1))
        if downsample_strides > 1:
            identity = AveragePooling2D(pool_size=(1, 1), strides=(2, 2))(identity)
        
        # Zero padding to match channels (it is important to use zero padding rather than a 1x1 convolution)
        if in_channels != out_channels:
            identity = Lambda(pad_backend, arguments={'in_channels': in_channels,
                                                      'out_channels': out_channels})(identity)
        
        residual = keras.layers.add([residual, identity])
    
    return residual


# define and train a model
inputs = Input(shape=input_shape)
net = Conv2D(8, 3, padding='same', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(inputs)
net = residual_shrinkage_block(net, 1, 8, downsample=True)
net = BatchNormalization()(net)
net = Activation('relu')(net)
net = GlobalAveragePooling2D()(net)
outputs = Dense(10, activation='softmax', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(net)
model = Model(inputs=inputs, outputs=outputs)
model.compile(loss='categorical_crossentropy', optimizer=Adam(), metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=100, epochs=5, verbose=1, validation_data=(x_test, y_test))

# get results
K.set_learning_phase(0)
DRSN_train_score = model.evaluate(x_train, y_train, batch_size=100, verbose=0)
print('Train loss:', DRSN_train_score[0])
print('Train accuracy:', DRSN_train_score[1])
DRSN_test_score = model.evaluate(x_test, y_test, batch_size=100, verbose=0)
print('Test loss:', DRSN_test_score[0])
print('Test accuracy:', DRSN_test_score[1])

The TFLearn program is as follows:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Dec 23 21:23:09 2019
Implemented using TensorFlow 1.0 and TFLearn 0.3.2
 
M. Zhao, S. Zhong, X. Fu, B. Tang, M. Pecht, Deep Residual Shrinkage Networks for Fault Diagnosis, 
IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898
 
@author: super_9527
"""
  
from __future__ import division, print_function, absolute_import
  
import tflearn
import numpy as np
import tensorflow as tf
from tflearn.layers.conv import conv_2d
  
# Data loading
from tflearn.datasets import cifar10
(X, Y), (testX, testY) = cifar10.load_data()
  
# Add noise
X = X + np.random.random((50000, 32, 32, 3))*0.1
testX = testX + np.random.random((10000, 32, 32, 3))*0.1
  
# Transform labels to one-hot format
Y = tflearn.data_utils.to_categorical(Y,10)
testY = tflearn.data_utils.to_categorical(testY,10)
  
def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
                   downsample_strides=2, activation='relu', batch_norm=True,
                   bias=True, weights_init='variance_scaling',
                   bias_init='zeros', regularizer='L2', weight_decay=0.0001,
                   trainable=True, restore=True, reuse=False, scope=None,
                   name="ResidualBlock"):
      
    # residual shrinkage blocks with channel-wise thresholds
  
    residual = incoming
    in_channels = incoming.get_shape().as_list()[-1]
  
    # Variable Scope fix for older TF
    try:
        vscope = tf.variable_scope(scope, default_name=name, values=[incoming],
                                   reuse=reuse)
    except Exception:
        vscope = tf.variable_op_scope([incoming], scope, name, reuse=reuse)
  
    with vscope as scope:
        name = scope.name #TODO
  
        for i in range(nb_blocks):
  
            identity = residual
  
            if not downsample:
                downsample_strides = 1
  
            if batch_norm:
                residual = tflearn.batch_normalization(residual)
            residual = tflearn.activation(residual, activation)
            residual = conv_2d(residual, out_channels, 3,
                             downsample_strides, 'same', 'linear',
                             bias, weights_init, bias_init,
                             regularizer, weight_decay, trainable,
                             restore)
  
            if batch_norm:
                residual = tflearn.batch_normalization(residual)
            residual = tflearn.activation(residual, activation)
            residual = conv_2d(residual, out_channels, 3, 1, 'same',
                             'linear', bias, weights_init,
                             bias_init, regularizer, weight_decay,
                             trainable, restore)
              
            # get thresholds and apply thresholding
            abs_mean = tf.reduce_mean(tf.reduce_mean(tf.abs(residual),axis=2,keep_dims=True),axis=1,keep_dims=True)
            scales = tflearn.fully_connected(abs_mean, out_channels//4, activation='linear',regularizer='L2',weight_decay=0.0001,weights_init='variance_scaling')
            scales = tflearn.batch_normalization(scales)
            scales = tflearn.activation(scales, 'relu')
            scales = tflearn.fully_connected(scales, out_channels, activation='linear',regularizer='L2',weight_decay=0.0001,weights_init='variance_scaling')
            scales = tf.expand_dims(tf.expand_dims(scales,axis=1),axis=1)
            thres = tf.multiply(abs_mean,tflearn.activations.sigmoid(scales))
            # soft thresholding
            residual = tf.multiply(tf.sign(residual), tf.maximum(tf.abs(residual)-thres,0))
              
  
            # Downsampling
            if downsample_strides > 1:
                identity = tflearn.avg_pool_2d(identity, 1,
                                               downsample_strides)
  
            # Projection to new dimension
            if in_channels != out_channels:
                if (out_channels - in_channels) % 2 == 0:
                    ch = (out_channels - in_channels)//2
                    identity = tf.pad(identity,
                                      [[0, 0], [0, 0], [0, 0], [ch, ch]])
                else:
                    ch = (out_channels - in_channels)//2
                    identity = tf.pad(identity,
                                      [[0, 0], [0, 0], [0, 0], [ch, ch+1]])
                in_channels = out_channels
  
            residual = residual + identity
  
    return residual
  
  
# Real-time data preprocessing
img_prep = tflearn.ImagePreprocessing()
img_prep.add_featurewise_zero_center(per_channel=True)
  
# Real-time data augmentation
img_aug = tflearn.ImageAugmentation()
img_aug.add_random_flip_leftright()
img_aug.add_random_crop([32, 32], padding=4)
  
# Build a Deep Residual Shrinkage Network with 3 blocks
net = tflearn.input_data(shape=[None, 32, 32, 3],
                         data_preprocessing=img_prep,
                         data_augmentation=img_aug)
net = tflearn.conv_2d(net, 16, 3, regularizer='L2', weight_decay=0.0001)
net = residual_shrinkage_block(net, 1, 16)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = tflearn.batch_normalization(net)
net = tflearn.activation(net, 'relu')
net = tflearn.global_avg_pool(net)
# Regression
net = tflearn.fully_connected(net, 10, activation='softmax')
mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=20000, staircase=True)
net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy')
# Training
model = tflearn.DNN(net, checkpoint_path='model_cifar10',
                    max_checkpoints=10, tensorboard_verbose=0,
                    clip_gradients=0.)
  
model.fit(X, Y, n_epoch=100, snapshot_epoch=False, snapshot_step=500,
          show_metric=True, batch_size=100, shuffle=True, run_id='model_cifar10')
  
training_acc = model.evaluate(X, Y)[0]
validation_acc = model.evaluate(testX, testY)[0]

 

Paper

M. Zhao, S. Zhong, X. Fu, et al., Deep residual shrinkage networks for fault diagnosis, IEEE Transactions on Industrial Informatics, DOI: 10.1109/TII.2019.2943898

https://ieeexplore.ieee.org/document/8850096

 
