ResidualAttentionNetwork: a TensorFlow + Keras implementation

Network structure

Overall network overview

[Figure: overall network architecture]
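
For reference, here is a rough shape trace of the forward pass built in _setup below, assuming a 224x224x3 input (the input size is an assumption; the code itself only fixes the 7x7 average-pooling window):

input                     224 x 224 x 3     (assumed)
pre_conv                  112 x 112 x 64    (7x7 conv, stride 2)
pre_pool                   56 x  56 x 64    (3x3 max pool, stride 2)
pre_res_1 / attention_1    56 x  56 x 256
pre_res_2 / attention_2    28 x  28 x 512
pre_res_3 / attention_3    14 x  14 x 1024
pre_res_4                   7 x   7 x 2048
ave_pool                    1 x   1 x 2048  (7x7 average pool)
reshape + Dense(2)          2 logits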

Residual Unit

[Figure: Residual Unit structure]
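
The residual_unit method below is a pre-activation bottleneck: the main path is BN -> 1x1 conv (c_out/4, relu) -> BN -> 3x3 conv (c_out/4, stride, relu) -> BN -> 1x1 conv (c_out, no activation), and the skip path is a 1x1 projection whenever c_in != c_out or stride > 1, so the unit computes y = F(x) + skip(x).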

Attention Module

[Figure: Attention Module structure]
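
The fusing step in attention_module below applies the attention residual learning rule from the paper: H(x) = (1 + M(x)) * T(x), where T(x) is the trunk-branch output and M(x) is the sigmoid soft mask from the mask branch, so the mask gates the trunk features without repeatedly shrinking them toward zero as modules stack.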

from tensorflow import keras as k
from tensorflow.contrib import layers
import tensorflow as tf
from bases.base_network import BaseNetwork
from utils.logger import logger


class ResidualAttentionNetwork(BaseNetwork):
    def __init__(self, inputs, is_training=True):
        super(ResidualAttentionNetwork, self).__init__(inputs, is_training)

    def _setup(self):
        # Pre convolution
        self.pre_conv = k.layers.Conv2D(64, kernel_size=7, strides=(2, 2), padding="SAME", name="pre_conv")(self.inputs)
        logger("pre_conv: {}".format(self.pre_conv.shape))
        # Pre max pooling
        self.pre_pool = k.layers.MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding="SAME", name="pre_pool")(self.pre_conv)
        logger("pre_pool: {}".format(self.pre_pool.shape))
        # Pre_res_1
        self.pre_res_1 = self.residual_unit(self.pre_pool, 64, 256, "pre_res_1")
        logger("pre_res_1: {}".format(self.pre_res_1.shape))
        # Attention_1
        self.attention_1 = self.attention_module(self.pre_res_1, 256, 256, "attention_1", skip_num=2)
        logger("attention_1: {}".format(self.attention_1.shape))
        # Pre_res_2
        self.pre_res_2 = self.residual_unit(self.attention_1, 256, 512, stride=2, name="pre_res_2")
        logger("pre_res_2: {}".format(self.pre_res_2.shape))
        # Attention_2
        self.attention_2 = self.attention_module(self.pre_res_2, 512, 512, "attention_2", skip_num=1)
        logger("attention_2: {}".format(self.attention_2.shape))
        # Pre_res_3
        self.pre_res_3 = self.residual_unit(self.attention_2, 512, 1024, stride=2, name="pre_res_3")
        logger("pre_res_3: {}".format(self.pre_res_3.shape))
        # Attention_3
        self.attention_3 = self.attention_module(self.pre_res_3, 1024, 1024, "attention_3", skip_num=0)
        logger("attention_3: {}".format(self.attention_3.shape))
        # Pre_res_4
        self.pre_res_4 = self.residual_unit(self.attention_3, 1024, 2048, stride=2, name="pre_res_4")
        logger("pre_res_4: {}".format(self.pre_res_4.shape))
        # Average pooling
        self.ave_pool = k.layers.AveragePooling2D(pool_size=(7, 7), strides=(1, 1), name="ave_pool")(self.pre_res_4)
        logger("ave_pool: {}".format(self.ave_pool.shape))
        # Reshape
        pool_shape = self.ave_pool.get_shape().as_list()
        logger("pool_shape: {}".format(pool_shape))
        fc_input = k.layers.Reshape(target_shape=[pool_shape[1] * pool_shape[2] * pool_shape[3]],
                                    name="reshape"
                                    )(self.ave_pool)
        logger("fc_input: {}".format(fc_input.shape))
        # Fully connection
        self.outputs = k.layers.Dense(2)(fc_input)
        logger("fc: {}".format(self.outputs.shape))

    def attention_module(self, x, c_in, c_out, name, p=1, t=2, r=1, skip_num=2):
        """
        Attention模塊
        """
        with tf.name_scope(name):
            # Pre-processing residual units (p units before the trunk/mask split)
            with tf.name_scope("pre_trunk"), tf.variable_scope("pre_trunk"):
                pre_trunk = x
                for idx in range(p):
                    unit_name = "trunk_1_{}".format(idx + 1)
                    pre_trunk = self.residual_unit(pre_trunk, c_in, c_out, unit_name)

            # trunk branch
            with tf.name_scope("trunk_branch"):
                trunks = pre_trunk
                for idx in range(t):
                    unit_name = "trunk_res_{}".format(idx + 1)
                    trunks = self.residual_unit(trunks, c_in, c_out, unit_name)

            # mask branch
            with tf.name_scope("mask_branch"):
                size_1 = pre_trunk.get_shape().as_list()[1:3]

                # max pooling
                max_pool_1 = k.layers.MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding="SAME", name="pool_1")(pre_trunk)

                # res -> skip -> wait to connection with interp
                #     |
                #    max -> res -> skip -> ...
                #               |
                #              max -> res
                down_res = max_pool_1
                skips = []
                sizes = []
                for skip_idx in range(skip_num):
                    # r residual units
                    for idx in range(r):
                        unit_name = "down_res{}_{}".format(skip_idx + 1, idx + 1)
                        down_res = self.residual_unit(down_res, c_in, c_out, unit_name)

                    # skip
                    skip_res = self.residual_unit(down_res, c_in, c_out, name="skip_res_{}".format(skip_idx + 1))
                    skips.append(skip_res)

                    # max
                    size = down_res.get_shape().as_list()[1:3]
                    sizes.append(size)
                    down_res = k.layers.MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding="SAME",
                                                  name="res_pool_{}".format(skip_idx)
                                                  )(down_res)

                # middle residual units at the lowest resolution
                middle_res = down_res
                for idx in range(2 * r):
                    unit_name = "down_res{}_{}".format(skip_num + 1, idx + 1)
                    middle_res = self.residual_unit(middle_res, c_in, c_out, unit_name)

                # Reverse skips: [skip1, skip2] -> [skip2, skip1]
                skips.reverse()
                # Reverse sizes: [res_size1, res_size2] -> [res_size2, res_size1]
                sizes.reverse()

                up_res = middle_res
                # res-upsample and skip connection
                # interp + skip -> res
                for skip_idx, (skip, size) in enumerate(zip(skips, sizes)):
                    # bilinear upsampling back to the matching skip connection's spatial size
                    interp = tf.image.resize_bilinear(up_res, size, name="interp_{}".format(skip_num + 1 - skip_idx))
                    # skip connection
                    up_res = skip + interp
                    # res
                    for idx in range(r):
                        unit_name = "up_res{}_{}".format(skip_num - skip_idx, idx + 1)
                        up_res = self.residual_unit(up_res, c_in, c_out, unit_name)

                # Interp
                interp = tf.image.resize_bilinear(up_res, size_1, name="interp_1")

                # Batch Normalization
                mask_bn_1 = k.layers.BatchNormalization(name="mask_bn_1")(interp, training=self.is_training)

                # 1x1 convolution
                linear_1 = k.layers.Conv2D(c_out,
                                           kernel_size=1,
                                           strides=(1, 1),
                                           name="linear_1",
                                           activation="relu",
                                           activity_regularizer=layers.l2_regularizer(scale=0.001, scope="linear_1_l2")
                                           )(mask_bn_1)

                # Batch Normalization
                mask_bn_2 = k.layers.BatchNormalization(name="mask_bn_2")(linear_1, training=self.is_training)

                # 1x1 convolution
                linear_2 = k.layers.Conv2D(c_out,
                                           kernel_size=1,
                                           strides=(1, 1),
                                           name="linear_2",
                                           activation="relu",
                                           activity_regularizer=layers.l2_regularizer(scale=0.001, scope="linear2_l2")
                                           )(mask_bn_2)
                # Sigmoid
                sigmoid = tf.nn.sigmoid(linear_2, "mask_sigmoid")

            # Fusing
            with tf.name_scope("fusing"):
                outputs = k.layers.Multiply(name="fusing")([trunks, sigmoid])
                outputs = k.layers.Add(name="fuse_add")([outputs, trunks])

            # last trunks
            with tf.name_scope("last_trunk"), tf.variable_scope("last_trunk"):
                for idx in range(p):
                    unit_name = "last_trunk_{}".format(idx + 1)
                    outputs = self.residual_unit(outputs, c_in, c_out, unit_name)
            return outputs

    def residual_unit(self, x, c_in, c_out, name, stride=1, padding="SAME", scale=0.001):
        """
        Residual Unit
        """
        with tf.name_scope(name):
            """
            tf.name_scope()用來管理命名空間
            tf.get_variable()創建的共享變量不起作用
            tf.name_scope()對tf.get_variable()創建的變量不起作用
            
            BatchNormalization: 在每一個批次的數據中標準化前一層的激活項即,
            應用一個維持激活項平均值接近0,標準方差接近1的轉換
            
            這裏有個大坑:
            tf.layers.BatchNormalization和tf.layers.batch_normalization會自動將
            update_ops添加到tf.GraphKeys.UPDATE_OPS這個collection中,當training=True時纔會添加;
            而tf.keras.layers.BatchNormalization不會自動將update_ops添加到tf.GraphKeys.UPDATE_OPS這個collection中。
            所以在TensorFlow訓練session中使用tf.keras.layers.BatchNormalization時,
            需要手動將keras.BatchNormal層的updates添加到tf.GraphKeys.UPDATE_OPS這個collection中。
            
            在訓練時,要將BatchNormalization中的training參數設置爲True,測試時設置爲False,
            在保存模型時,要將\mu和\delta保存,這兩個參數存放在tf.global_variables()中
            var_list = tf.trainable_variables()
            g_list = tf.global_variables()
            bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
            bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
            var_list += bn_moving_vars
            saver = tf.train.Saver(var_list=var_list, max_to_keep=5)

            """
            # Batch Normalization
            # bn_1 = tf.layers.batch_normalization(x, training=self.is_training, name="bn_1")
            bn_1 = k.layers.BatchNormalization(name="bn_1")(x, training=self.is_training)

            # 1x1 Convolution
            conv_1 = k.layers.Conv2D(c_out//4,
                                     kernel_size=1,
                                     strides=(1, 1),
                                     padding=padding,
                                     name="conv_1",
                                     activation="relu",
                                     activity_regularizer=layers.l2_regularizer(scale=scale, scope="conv_1_l2")
                                     )(bn_1)

            # Batch Normalization
            bn_2 = k.layers.BatchNormalization(name="bn_2")(conv_1, training=self.is_training)

            # 3x3 Convolution
            conv_2 = k.layers.Conv2D(c_out//4,
                                     kernel_size=3,
                                     strides=(stride, stride),
                                     padding=padding,
                                     name="conv_2",
                                     activation="relu",
                                     activity_regularizer=layers.l2_regularizer(scale=scale, scope="conv_2_l2")
                                     )(bn_2)

            # Batch Normalization
            bn_3 = k.layers.BatchNormalization(name="bn_3")(conv_2, training=self.is_training)

            # 1x1 Convolution
            # no activation function
            conv_3 = k.layers.Conv2D(c_out,
                                     kernel_size=1,
                                     strides=(1, 1),
                                     padding=padding,
                                     name="conv_3",
                                     activation=None,
                                     activity_regularizer=layers.l2_regularizer(scale=scale, scope="conv_3_l2")
                                     )(bn_3)

            # Skip connection
            # no activation function
            if c_out != c_in or stride > 1:
                skip = k.layers.Conv2D(c_out,
                                       kernel_size=1,
                                       strides=(stride, stride),
                                       padding=padding,
                                       name="conv_skip",
                                       activation=None,
                                       activity_regularizer=layers.l2_regularizer(scale=scale, scope="skip_l2")
                                       )(x)
            else:
                skip = x
            outputs = k.layers.Add(name="fuse")([conv_3, skip])
            return outputs
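
Since tf.keras.layers.BatchNormalization does not register its moving-average updates in tf.GraphKeys.UPDATE_OPS (the pitfall described in the comment inside residual_unit), those updates have to be wired into the training step by hand. Below is a minimal sketch of how a training graph might be set up; the input shape, the loss, the optimizer, the module path, and the assumption that BaseNetwork.__init__ calls _setup() and exposes self.outputs are all illustrative, not part of the original code.

import tensorflow as tf
from networks.residual_attention_network import ResidualAttentionNetwork  # hypothetical module path

# Assumed placeholders: 224x224 RGB images, 2 classes
inputs = tf.placeholder(tf.float32, [None, 224, 224, 3], name="inputs")
labels = tf.placeholder(tf.int64, [None], name="labels")

# Assumes BaseNetwork.__init__ runs _setup() and stores the logits in self.outputs
net = ResidualAttentionNetwork(inputs, is_training=True)
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=net.outputs))

# Pattern for registering a Keras BN layer's updates manually (done wherever the
# layer is created, e.g. inside residual_unit):
#     bn_layer = k.layers.BatchNormalization(name="bn_1")
#     out = bn_layer(x, training=is_training)
#     for update in bn_layer.updates:
#         tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update)

# Once the updates are in UPDATE_OPS, make the train op depend on them
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)

# Save moving_mean / moving_variance together with the trainable variables
var_list = tf.trainable_variables()
var_list += [g for g in tf.global_variables()
             if "moving_mean" in g.name or "moving_variance" in g.name]
saver = tf.train.Saver(var_list=var_list, max_to_keep=5)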

