"""Network structure.

- Overall network overview
- Residual Unit
- Attention Module
"""
from tensorflow import keras as k
from tensorflow.contrib import layers
import tensorflow as tf
from bases.base_network import BaseNetwork
from utils.logger import logger
class ResidualAttentionNetwork(BaseNetwork):
    """Residual Attention Network.

    Builds the graph:
        pre conv -> max pool -> [residual -> attention] x 3 -> residual
        -> average pool -> reshape -> dense(2)

    NOTE(review): assumes `self.inputs` is an NHWC image tensor whose
    spatial size has been reduced to 7x7 by the time the 7x7 average
    pooling runs (i.e. a 224x224-style input) — confirm against callers.
    """

    def __init__(self, inputs, is_training=True):
        """Delegate input tensor / training-flag storage to BaseNetwork."""
        super(ResidualAttentionNetwork, self).__init__(inputs, is_training)

    def _setup(self):
        """Assemble the whole network; logits end up in ``self.outputs``."""
        # Pre convolution: 7x7/2, 64 filters.
        self.pre_conv = k.layers.Conv2D(64, kernel_size=7, strides=(2, 2),
                                        padding="SAME", name="pre_conv")(self.inputs)
        logger("pre_conv: {}".format(self.pre_conv.shape))
        # Pre max pooling: 3x3/2.
        self.pre_pool = k.layers.MaxPool2D(pool_size=(3, 3), strides=(2, 2),
                                           padding="SAME", name="pre_pool")(self.pre_conv)
        logger("pre_pool: {}".format(self.pre_pool.shape))
        # Stage 1: residual 64 -> 256, then attention with 2 skip levels.
        self.pre_res_1 = self.residual_unit(self.pre_pool, 64, 256, "pre_res_1")
        logger("pre_res_1: {}".format(self.pre_res_1.shape))
        self.attention_1 = self.attention_module(self.pre_res_1, 256, 256,
                                                 "attention_1", skip_num=2)
        logger("attention_1: {}".format(self.attention_1.shape))
        # Stage 2: downsampling residual 256 -> 512, attention with 1 skip level.
        self.pre_res_2 = self.residual_unit(self.attention_1, 256, 512,
                                            stride=2, name="pre_res_2")
        logger("pre_res_2: {}".format(self.pre_res_2.shape))
        self.attention_2 = self.attention_module(self.pre_res_2, 512, 512,
                                                 "attention_2", skip_num=1)
        logger("attention_2: {}".format(self.attention_2.shape))
        # Stage 3: downsampling residual 512 -> 1024, attention with no skip level.
        self.pre_res_3 = self.residual_unit(self.attention_2, 512, 1024,
                                            stride=2, name="pre_res_3")
        logger("pre_res_3: {}".format(self.pre_res_3.shape))
        self.attention_3 = self.attention_module(self.pre_res_3, 1024, 1024,
                                                 "attention_3", skip_num=0)
        logger("attention_3: {}".format(self.attention_3.shape))
        # Final downsampling residual 1024 -> 2048.
        self.pre_res_4 = self.residual_unit(self.attention_3, 1024, 2048,
                                            stride=2, name="pre_res_4")
        logger("pre_res_4: {}".format(self.pre_res_4.shape))
        # Global 7x7 average pooling.
        self.ave_pool = k.layers.AveragePooling2D(pool_size=(7, 7), strides=(1, 1),
                                                  name="ave_pool")(self.pre_res_4)
        logger("ave_pool: {}".format(self.ave_pool.shape))
        # Flatten the pooled feature map for the classifier head.
        pool_shape = self.ave_pool.get_shape().as_list()
        logger("pool_shape: {}".format(pool_shape))
        fc_input = k.layers.Reshape(target_shape=[pool_shape[1] * pool_shape[2] * pool_shape[3]],
                                    name="reshape"
                                    )(self.ave_pool)
        logger("fc_input: {}".format(fc_input.shape))
        # Fully-connected head: 2-way logits.
        self.outputs = k.layers.Dense(2)(fc_input)
        logger("fc: {}".format(self.outputs.shape))

    def attention_module(self, x, c_in, c_out, name, p=1, t=2, r=1, skip_num=2):
        """Attention module: output = trunk * (1 + soft mask).

        Args:
            x: input feature map (NHWC).
            c_in: input channel count, forwarded to the residual units.
            c_out: output channel count.
            name: name scope of the module.
            p: number of residual units before and after the module.
            t: number of residual units in the trunk branch.
            r: number of residual units per down/up step of the mask branch.
            skip_num: number of skip-connection (down/up) levels in the
                mask branch; 0 means no extra down/up level.

        Returns:
            Output tensor with the same spatial size as `x` and `c_out` channels.
        """
        with tf.name_scope(name):
            # p leading residual units shared by both branches.
            with tf.name_scope("pre_trunk"), tf.variable_scope("pre_trunk"):
                pre_trunk = x
                for idx in range(p):
                    unit_name = "trunk_1_{}".format(idx + 1)
                    pre_trunk = self.residual_unit(pre_trunk, c_in, c_out, unit_name)
            # Trunk branch: t residual units.
            with tf.name_scope("trunk_branch"):
                trunks = pre_trunk
                for idx in range(t):
                    unit_name = "trunk_res_{}".format(idx + 1)
                    trunks = self.residual_unit(trunks, c_in, c_out, unit_name)
            # Mask branch: an hourglass of max-pool downsampling and bilinear
            # upsampling with skip connections:
            #   res -> skip -> wait to connect with interp
            #    |
            #   max -> res -> skip -> ...
            #    |
            #   max -> res
            with tf.name_scope("mask_branch"):
                # Remember the trunk resolution for the final upsampling.
                size_1 = pre_trunk.get_shape().as_list()[1:3]
                max_pool_1 = k.layers.MaxPool2D(pool_size=(3, 3), strides=(2, 2),
                                                padding="SAME", name="pool_1")(pre_trunk)
                down_res = max_pool_1
                skips = []
                sizes = []
                for skip_idx in range(skip_num):
                    # r residual units at this resolution.
                    for idx in range(r):
                        unit_name = "down_res{}_{}".format(skip_idx + 1, idx + 1)
                        down_res = self.residual_unit(down_res, c_in, c_out, unit_name)
                    # Skip-connection branch kept aside for the upsampling path.
                    skip_res = self.residual_unit(down_res, c_in, c_out,
                                                  name="skip_res_{}".format(skip_idx + 1))
                    skips.append(skip_res)
                    # Remember the pre-pool size so upsampling can restore it,
                    # then downsample one more level.
                    size = down_res.get_shape().as_list()[1:3]
                    sizes.append(size)
                    # FIX: name was "res_pool_{}".format(skip_idx) (0-based),
                    # inconsistent with the 1-based sibling names above.
                    down_res = k.layers.MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding="SAME",
                                                  name="res_pool_{}".format(skip_idx + 1)
                                                  )(down_res)
                # 2r residual units at the lowest resolution.
                middle_res = down_res
                for idx in range(2 * r):
                    unit_name = "down_res{}_{}".format(skip_num + 1, idx + 1)
                    middle_res = self.residual_unit(middle_res, c_in, c_out, unit_name)
                # Walk back up, deepest skip first:
                # [skip1, skip2] -> [skip2, skip1], same for the sizes.
                skips.reverse()
                sizes.reverse()
                up_res = middle_res
                # interp + skip -> res, per level.
                for skip_idx, (skip, size) in enumerate(zip(skips, sizes)):
                    # Bilinear upsampling back to this skip's resolution.
                    # FIX: name was "interp_".format(...) — the "{}" placeholder
                    # was missing, so the level number was silently dropped.
                    interp = tf.image.resize_bilinear(
                        up_res, size, name="interp_{}".format(skip_num + 1 - skip_idx))
                    # Skip connection.
                    up_res = skip + interp
                    for idx in range(r):
                        unit_name = "up_res{}_{}".format(skip_num - skip_idx, idx + 1)
                        up_res = self.residual_unit(up_res, c_in, c_out, unit_name)
                # Final upsampling back to the trunk resolution.
                interp = tf.image.resize_bilinear(up_res, size_1, name="interp_1")
                # BN -> 1x1 conv, twice, then sigmoid to produce the soft mask.
                mask_bn_1 = k.layers.BatchNormalization(name="mask_bn_1")(
                    interp, training=self.is_training)
                linear_1 = k.layers.Conv2D(c_out,
                                           kernel_size=1,
                                           strides=(1, 1),
                                           name="linear_1",
                                           activation="relu",
                                           activity_regularizer=layers.l2_regularizer(
                                               scale=0.001, scope="linear_1_l2")
                                           )(mask_bn_1)
                mask_bn_2 = k.layers.BatchNormalization(name="mask_bn_2")(
                    linear_1, training=self.is_training)
                linear_2 = k.layers.Conv2D(c_out,
                                           kernel_size=1,
                                           strides=(1, 1),
                                           name="linear_2",
                                           activation="relu",
                                           activity_regularizer=layers.l2_regularizer(
                                               scale=0.001, scope="linear2_l2")
                                           )(mask_bn_2)
                sigmoid = tf.nn.sigmoid(linear_2, "mask_sigmoid")
            # Fuse: trunk * mask + trunk == trunk * (1 + mask).
            with tf.name_scope("fusing"):
                outputs = k.layers.Multiply(name="fusing")([trunks, sigmoid])
                outputs = k.layers.Add(name="fuse_add")([outputs, trunks])
            # p trailing residual units.
            with tf.name_scope("last_trunk"), tf.variable_scope("last_trunk"):
                for idx in range(p):
                    unit_name = "last_trunk_{}".format(idx + 1)
                    outputs = self.residual_unit(outputs, c_in, c_out, unit_name)
        return outputs

    def residual_unit(self, x, c_in, c_out, name, stride=1, padding="SAME", scale=0.001):
        """Pre-activation bottleneck residual unit.

        BN -> 1x1 conv (c_out//4) -> BN -> 3x3 conv (c_out//4, stride)
        -> BN -> 1x1 conv (c_out, no activation), added to a shortcut.
        The shortcut is a 1x1 projection whenever the channel count or the
        spatial size changes, otherwise the identity.

        Args:
            x: input feature map (NHWC).
            c_in: input channel count — only used to decide whether the
                shortcut needs a projection.
            c_out: output channel count.
            name: name scope of the unit.
            stride: stride of the middle 3x3 convolution.
            padding: convolution padding mode.
            scale: L2 activity-regularizer scale for every convolution.

        Returns:
            The output tensor of the unit.
        """
        with tf.name_scope(name):
            # NOTE: tf.name_scope() only manages op names; it has no effect
            # on variables created through tf.get_variable().
            #
            # Pitfall: tf.layers.batch_normalization automatically adds its
            # update ops to the tf.GraphKeys.UPDATE_OPS collection (when
            # training=True), but tf.keras.layers.BatchNormalization does
            # NOT — when training with a raw TF session the layer's
            # `updates` must be added to UPDATE_OPS manually.  `training`
            # must be True while training and False at test time, and the
            # moving mean/variance live in tf.global_variables(), so they
            # must be included in the Saver's var_list, e.g.:
            #   var_list = tf.trainable_variables()
            #   g_list = tf.global_variables()
            #   bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
            #   bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
            #   var_list += bn_moving_vars
            #   saver = tf.train.Saver(var_list=var_list, max_to_keep=5)
            bn_1 = k.layers.BatchNormalization(name="bn_1")(x, training=self.is_training)
            # 1x1 bottleneck convolution.
            conv_1 = k.layers.Conv2D(c_out // 4,
                                     kernel_size=1,
                                     strides=(1, 1),
                                     padding=padding,
                                     name="conv_1",
                                     activation="relu",
                                     activity_regularizer=layers.l2_regularizer(
                                         scale=scale, scope="conv_1_l2")
                                     )(bn_1)
            bn_2 = k.layers.BatchNormalization(name="bn_2")(conv_1, training=self.is_training)
            # 3x3 convolution (carries the stride).
            conv_2 = k.layers.Conv2D(c_out // 4,
                                     kernel_size=3,
                                     strides=(stride, stride),
                                     padding=padding,
                                     name="conv_2",
                                     activation="relu",
                                     activity_regularizer=layers.l2_regularizer(
                                         scale=scale, scope="conv_2_l2")
                                     )(bn_2)
            bn_3 = k.layers.BatchNormalization(name="bn_3")(conv_2, training=self.is_training)
            # 1x1 expansion convolution — no activation before the addition.
            conv_3 = k.layers.Conv2D(c_out,
                                     kernel_size=1,
                                     strides=(1, 1),
                                     padding=padding,
                                     name="conv_3",
                                     activation=None,
                                     activity_regularizer=layers.l2_regularizer(
                                         scale=scale, scope="conv_3_l2")
                                     )(bn_3)
            # Shortcut: 1x1 projection (no activation) when the shape
            # changes, identity otherwise.
            if c_out != c_in or stride > 1:
                skip = k.layers.Conv2D(c_out,
                                       kernel_size=1,
                                       strides=(stride, stride),
                                       padding=padding,
                                       name="conv_skip",
                                       activation=None,
                                       activity_regularizer=layers.l2_regularizer(
                                           scale=scale, scope="skip_l2")
                                       )(x)
            else:
                skip = x
            outputs = k.layers.Add(name="fuse")([conv_3, skip])
        return outputs