ResNet TensorFlow Code Analysis

Reference: https://github.com/chaipangpang/ResNet_cifar

A detailed code analysis follows:

main.py

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""ResNet Train/Eval module.
"""
import time
import six
import sys

import cifar_input
import numpy as np
import resnet_model
import tensorflow as tf


# FLAGS parameter definitions
FLAGS = tf.app.flags.FLAGS
# Dataset type
tf.app.flags.DEFINE_string('dataset', 
                           'cifar10', 
                           'cifar10 or cifar100.')
# Mode: train or eval
tf.app.flags.DEFINE_string('mode', 
                           'train', 
                           'train or eval.')
# Training data path
tf.app.flags.DEFINE_string('train_data_path', 
                           'data/cifar-10-batches-bin/data_batch*',
                           'Filepattern for training data.')
# Eval data path
tf.app.flags.DEFINE_string('eval_data_path', 
                           'data/cifar-10-batches-bin/test_batch.bin',
                           'Filepattern for eval data')
# Image size
tf.app.flags.DEFINE_integer('image_size', 
                            32, 
                            'Image side length.')
# Directory for training outputs
tf.app.flags.DEFINE_string('train_dir', 
                           'temp/train',
                           'Directory to keep training outputs.')
# Directory for eval outputs
tf.app.flags.DEFINE_string('eval_dir', 
                           'temp/eval',
                           'Directory to keep eval outputs.')
# Number of batches to evaluate
tf.app.flags.DEFINE_integer('eval_batch_count', 
                            50,
                            'Number of batches to eval.')
# Whether to evaluate only once
tf.app.flags.DEFINE_bool('eval_once', 
                         False,
                         'Whether evaluate the model only once.')
# Checkpoint directory
tf.app.flags.DEFINE_string('log_root', 
                           'temp',
                           'Directory to keep the checkpoints. Should be a '
                           'parent directory of FLAGS.train_dir/eval_dir.')
# Number of GPU devices (0 means CPU)
tf.app.flags.DEFINE_integer('num_gpus', 
                            1,
                            'Number of gpus used for training. (0 or 1)')


def train(hps):
  # Build the input pipeline (queue runners)
  images, labels = cifar_input.build_input(
      FLAGS.dataset, FLAGS.train_data_path, hps.batch_size, FLAGS.mode)
  # Build the residual network model
  model = resnet_model.ResNet(hps, images, labels, FLAGS.mode)
  model.build_graph()

  # Compute prediction accuracy
  truth = tf.argmax(model.labels, axis=1)
  predictions = tf.argmax(model.predictions, axis=1)
  precision = tf.reduce_mean(tf.to_float(tf.equal(predictions, truth)))

  # Summary saver hook: writes summaries every 100 steps
  summary_hook = tf.train.SummarySaverHook(
              save_steps=100,
              output_dir=FLAGS.train_dir,
              summary_op=tf.summary.merge(
                              [model.summaries,
                               tf.summary.scalar('Precision', precision)]))
  # Logging hook: prints the tensors below every 100 steps
  logging_hook = tf.train.LoggingTensorHook(
      tensors={'step': model.global_step,
               'loss': model.cost,
               'precision': precision},
      every_n_iter=100)

  # Learning-rate scheduler hook, driven by the global step
  class _LearningRateSetterHook(tf.train.SessionRunHook):

    def begin(self):
      # Initial learning rate
      self._lrn_rate = 0.1

    def before_run(self, run_context):
      return tf.train.SessionRunArgs(
                      # Fetch the global step
                      model.global_step,
                      # Feed in the current learning rate
                      feed_dict={model.lrn_rate: self._lrn_rate})

    def after_run(self, run_context, run_values):
      # Update the learning rate according to the schedule
      train_step = run_values.results
      if train_step < 40000:
        self._lrn_rate = 0.1
      elif train_step < 60000:
        self._lrn_rate = 0.01
      elif train_step < 80000:
        self._lrn_rate = 0.001
      else:
        self._lrn_rate = 0.0001

  # Create a monitored session. Given checkpoint_dir, it restores from and
  # periodically saves checkpoints to FLAGS.log_root.
  with tf.train.MonitoredTrainingSession(
      checkpoint_dir=FLAGS.log_root,
      hooks=[logging_hook, _LearningRateSetterHook()],
      chief_only_hooks=[summary_hook],
      # Disable the default SummarySaverHook by setting save_summaries_steps to 0
      save_summaries_steps=0, 
      config=tf.ConfigProto(allow_soft_placement=True)) as mon_sess:
    while not mon_sess.should_stop():
      # Run one training step
      mon_sess.run(model.train_op)


def evaluate(hps):
  # Build the input pipeline (queue runners)
  images, labels = cifar_input.build_input(
      FLAGS.dataset, FLAGS.eval_data_path, hps.batch_size, FLAGS.mode)
  # Build the residual network model
  model = resnet_model.ResNet(hps, images, labels, FLAGS.mode)
  model.build_graph()
  # Saver for restoring the model variables
  saver = tf.train.Saver()
  # Summary file writer
  summary_writer = tf.summary.FileWriter(FLAGS.eval_dir)
  
  # Create the session
  sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
  
  # Start all queue runners
  tf.train.start_queue_runners(sess)

  best_precision = 0.0
  while True:
    # Look for the latest checkpoint file
    try:
      ckpt_state = tf.train.get_checkpoint_state(FLAGS.log_root)
    except tf.errors.OutOfRangeError as e:
      tf.logging.error('Cannot restore checkpoint: %s', e)
      continue
    if not (ckpt_state and ckpt_state.model_checkpoint_path):
      tf.logging.info('No model to eval yet at %s', FLAGS.log_root)
      continue
  
    # Restore the model parameters produced during training
    tf.logging.info('Loading checkpoint %s', ckpt_state.model_checkpoint_path)
    saver.restore(sess, ckpt_state.model_checkpoint_path)

    # Run evaluation batch by batch
    total_prediction, correct_prediction = 0, 0
    for _ in six.moves.range(FLAGS.eval_batch_count):
      # Run the prediction ops
      (loss, predictions, truth, train_step) = sess.run(
          [model.cost, model.predictions,
           model.labels, model.global_step])
      # Accumulate prediction statistics
      truth = np.argmax(truth, axis=1)
      predictions = np.argmax(predictions, axis=1)
      correct_prediction += np.sum(truth == predictions)
      total_prediction += predictions.shape[0]

    # Compute the precision
    precision = 1.0 * correct_prediction / total_prediction
    best_precision = max(precision, best_precision)

    # Add a precision summary
    precision_summ = tf.Summary()
    precision_summ.value.add(
        tag='Precision', simple_value=precision)
    summary_writer.add_summary(precision_summ, train_step)
    
    # Add a best-precision summary
    best_precision_summ = tf.Summary()
    best_precision_summ.value.add(
        tag='Best Precision', simple_value=best_precision)
    summary_writer.add_summary(best_precision_summ, train_step)
    
    # Add the model summaries (disabled)
    #summary_writer.add_summary(summaries, train_step)
    
    # Log the results
    tf.logging.info('loss: %.3f, precision: %.3f, best precision: %.3f' %
                    (loss, precision, best_precision))
    
    # Flush summaries to disk
    summary_writer.flush()

    if FLAGS.eval_once:
      break

    time.sleep(60)


def main(_):
  # Device selection
  if FLAGS.num_gpus == 0:
    dev = '/cpu:0'
  elif FLAGS.num_gpus == 1:
    dev = '/gpu:0'
  else:
    raise ValueError('Only support 0 or 1 gpu.')
    
  # Batch size depends on the mode
  if FLAGS.mode == 'train':
    batch_size = 128
  elif FLAGS.mode == 'eval':
    batch_size = 100

  # Number of classes for the dataset
  if FLAGS.dataset == 'cifar10':
    num_classes = 10
  elif FLAGS.dataset == 'cifar100':
    num_classes = 100

  # Residual network hyperparameters
  hps = resnet_model.HParams(batch_size=batch_size,
                             num_classes=num_classes,
                             min_lrn_rate=0.0001,
                             lrn_rate=0.1,
                             num_residual_units=5,
                             use_bottleneck=False,
                             weight_decay_rate=0.0002,
                             relu_leakiness=0.1,
                             optimizer='mom')
  # Run training or evaluation
  with tf.device(dev):
    if FLAGS.mode == 'train':
      train(hps)
    elif FLAGS.mode == 'eval':
      evaluate(hps)


if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run()
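Note that training and evaluation are intended to run as two separate processes (one started with --mode=train, one with --mode=eval) sharing FLAGS.log_root: the trainer writes checkpoints there and the evaluator polls for new ones every 60 seconds. The staircase schedule implemented by _LearningRateSetterHook can be restated as a plain function (a sketch for clarity, not part of the original code):

def lr_schedule(train_step):
  """Decays the learning rate by 10x at 40k, 60k and 80k global steps."""
  if train_step < 40000:
    return 0.1
  elif train_step < 60000:
    return 0.01
  elif train_step < 80000:
    return 0.001
  return 0.0001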

cifar_input.py

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""CIFAR dataset input module.
"""

import tensorflow as tf

def build_input(dataset, data_path, batch_size, mode):
  """Build CIFAR image and labels.
  Args:
    dataset: Either 'cifar10' or 'cifar100'.
    data_path: File pattern for the data.
    batch_size: Input batch size.
    mode: Either 'train' or 'eval'.
  Returns:
    images: Batches of images. [batch_size, image_size, image_size, 3]
    labels: Batches of labels. [batch_size, num_classes]
  Raises:
    ValueError: when the specified dataset is not supported.
  """
  
  # Dataset parameters
  image_size = 32
  if dataset == 'cifar10':
    label_bytes = 1
    label_offset = 0
    num_classes = 10
  elif dataset == 'cifar100':
    label_bytes = 1
    label_offset = 1
    num_classes = 100
  else:
    raise ValueError('Unsupported dataset %s' % dataset)

  # Record layout parameters
  depth = 3
  image_bytes = image_size * image_size * depth
  record_bytes = label_bytes + label_offset + image_bytes

  # Expand the file pattern into a list of filenames
  data_files = tf.gfile.Glob(data_path)
  # Filename queue (shuffled)
  file_queue = tf.train.string_input_producer(data_files, shuffle=True)
  # Read fixed-length binary records from the filename queue
  reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
  _, value = reader.read(file_queue)

  # Decode the raw bytes into an image and a class label
  record = tf.reshape(tf.decode_raw(value, tf.uint8), [record_bytes])
  label = tf.cast(tf.slice(record, [label_offset], [label_bytes]), tf.int32)
  # Reshape the byte string [depth * height * width] into [depth, height, width].
  depth_major = tf.reshape(
      tf.slice(record, [label_offset + label_bytes], [image_bytes]),
      [depth, image_size, image_size])
  # Transpose [depth, height, width] into [height, width, depth].
  image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)

  if mode == 'train':
    # Pad the image to (image_size+4) x (image_size+4)
    image = tf.image.resize_image_with_crop_or_pad(
                        image, image_size+4, image_size+4)
    # Randomly crop back to image_size x image_size
    image = tf.random_crop(image, [image_size, image_size, 3])
    # Randomly flip the image horizontally
    image = tf.image.random_flip_left_right(image)
    # Per-image standardization (subtract the mean, divide by the stddev)
    image = tf.image.per_image_standardization(image)

    # Input example queue (randomly shuffled)
    example_queue = tf.RandomShuffleQueue(
        # Queue capacity
        capacity=16 * batch_size,
        # Minimum number of elements left after a dequeue
        min_after_dequeue=8 * batch_size,
        dtypes=[tf.float32, tf.int32],
        # Image shape, label shape
        shapes=[[image_size, image_size, depth], [1]])
    # Number of reader threads
    num_threads = 16
  else:
    # For eval: center-crop/pad the image to size and standardize it
    image = tf.image.resize_image_with_crop_or_pad(
                        image, image_size, image_size)
    image = tf.image.per_image_standardization(image)

    # Input example queue (first-in, first-out)
    example_queue = tf.FIFOQueue(
        3 * batch_size,
        dtypes=[tf.float32, tf.int32],
        shapes=[[image_size, image_size, depth], [1]])
    # Number of reader threads
    num_threads = 1

  # Enqueue op
  example_enqueue_op = example_queue.enqueue([image, label])
  # Queue runner
  tf.train.add_queue_runner(tf.train.queue_runner.QueueRunner(
      example_queue, [example_enqueue_op] * num_threads))

  # Dequeue a batch of examples
  images, labels = example_queue.dequeue_many(batch_size)
  # Convert sparse labels into dense one-hot labels, e.g.
  # [ 2,       [[0,1,0,0,0]
  #   4,        [0,0,0,1,0]  
  #   3,   -->  [0,0,1,0,0]    
  #   5,        [0,0,0,0,1]
  #   1 ]       [1,0,0,0,0]]
  labels = tf.reshape(labels, [batch_size, 1])
  indices = tf.reshape(tf.range(0, batch_size, 1), [batch_size, 1])
  labels = tf.sparse_to_dense(
                  tf.concat(values=[indices, labels], axis=1),
                  [batch_size, num_classes], 1.0, 0.0)
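  # Note: tf.one_hot would be an equivalent, more direct conversion here, e.g.
  # labels = tf.one_hot(tf.reshape(labels, [batch_size]), num_classes)
  # (a hypothetical drop-in for the reshape/sparse_to_dense lines above).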

  # Sanity-check the tensor shapes
  assert len(images.get_shape()) == 4
  assert images.get_shape()[0] == batch_size
  assert images.get_shape()[-1] == 3
  assert len(labels.get_shape()) == 2
  assert labels.get_shape()[0] == batch_size
  assert labels.get_shape()[1] == num_classes

  # Add an image summary
  tf.summary.image('images', images)
  return images, labels
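The fixed-length record layout that tf.FixedLengthRecordReader relies on can be checked with a little arithmetic (CIFAR-10 shown; for CIFAR-100, label_offset=1 skips the coarse-label byte that precedes the fine label in each record):

# Record layout arithmetic for the fixed-length reader (CIFAR-10).
image_size, depth = 32, 3
label_bytes, label_offset = 1, 0                          # CIFAR-100: 1, 1
image_bytes = image_size * image_size * depth             # 32 * 32 * 3 = 3072
record_bytes = label_bytes + label_offset + image_bytes   # 3073 bytes per record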

resnet_model.py

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""ResNet model.
Related papers:
https://arxiv.org/pdf/1603.05027v2.pdf
https://arxiv.org/pdf/1512.03385v1.pdf
https://arxiv.org/pdf/1605.07146v1.pdf
"""
from collections import namedtuple

import numpy as np
import tensorflow as tf
import six

from tensorflow.python.training import moving_averages


HParams = namedtuple('HParams',
                     'batch_size, num_classes, min_lrn_rate, lrn_rate, '
                     'num_residual_units, use_bottleneck, weight_decay_rate, '
                     'relu_leakiness, optimizer')


class ResNet(object):
  """ResNet model."""

  def __init__(self, hps, images, labels, mode):
    """ResNet constructor.
    Args:
      hps: Hyperparameters.
      images: Batches of images. [batch_size, image_size, image_size, 3]
      labels: Batches of labels. [batch_size, num_classes]
      mode: One of 'train' and 'eval'.
    """
    self.hps = hps
    self._images = images
    self.labels = labels
    self.mode = mode

    self._extra_train_ops = []

  # Build the model graph
  def build_graph(self):
    # Create the global step
    self.global_step = tf.contrib.framework.get_or_create_global_step()
    # Build the ResNet network
    self._build_model()
    # Build the training op (train mode only)
    if self.mode == 'train':
      self._build_train_op()
    # Merge all summaries
    self.summaries = tf.summary.merge_all()


  # Build the model
  def _build_model(self):
    with tf.variable_scope('init'):
      x = self._images
      """第一層卷積(3,3x3/1,16)"""
      x = self._conv('init_conv', x, 3, 3, 16, self._stride_arr(1))

    # Residual network parameters
    strides = [1, 2, 2]
    # Whether to apply BN/ReLU before the residual branch (pre-activation)
    activate_before_residual = [True, False, False]
    if self.hps.use_bottleneck:
      # Bottleneck residual unit
      res_func = self._bottleneck_residual
      # Channel counts
      filters = [16, 64, 128, 256]
    else:
      # Standard residual unit
      res_func = self._residual
      # Channel counts
      filters = [16, 16, 32, 64]

    # Group 1
    with tf.variable_scope('unit_1_0'):
      x = res_func(x, filters[0], filters[1], 
                   self._stride_arr(strides[0]),
                   activate_before_residual[0])
    for i in six.moves.range(1, self.hps.num_residual_units):
      with tf.variable_scope('unit_1_%d' % i):
        x = res_func(x, filters[1], filters[1], self._stride_arr(1), False)

    # Group 2
    with tf.variable_scope('unit_2_0'):
      x = res_func(x, filters[1], filters[2], 
                   self._stride_arr(strides[1]),
                   activate_before_residual[1])
    for i in six.moves.range(1, self.hps.num_residual_units):
      with tf.variable_scope('unit_2_%d' % i):
        x = res_func(x, filters[2], filters[2], self._stride_arr(1), False)
        
    # Group 3
    with tf.variable_scope('unit_3_0'):
      x = res_func(x, filters[2], filters[3], self._stride_arr(strides[2]),
                   activate_before_residual[2])
    for i in six.moves.range(1, self.hps.num_residual_units):
      with tf.variable_scope('unit_3_%d' % i):
        x = res_func(x, filters[3], filters[3], self._stride_arr(1), False)

    # Global average pooling layer
    with tf.variable_scope('unit_last'):
      x = self._batch_norm('final_bn', x)
      x = self._relu(x, self.hps.relu_leakiness)
      x = self._global_avg_pool(x)

    # Fully connected layer + softmax
    with tf.variable_scope('logit'):
      logits = self._fully_connected(x, self.hps.num_classes)
      self.predictions = tf.nn.softmax(logits)

    # Loss function
    with tf.variable_scope('costs'):
      # Cross entropy
      xent = tf.nn.softmax_cross_entropy_with_logits(
          logits=logits, labels=self.labels)
      # Mean cross entropy over the batch
      self.cost = tf.reduce_mean(xent, name='xent')
      # L2 regularization (weight decay)
      self.cost += self._decay()
      # Add a cost summary for TensorBoard
      tf.summary.scalar('cost', self.cost)

  # Build the training op
  def _build_train_op(self):
    # Learning rate
    self.lrn_rate = tf.constant(self.hps.lrn_rate, tf.float32)
    tf.summary.scalar('learning_rate', self.lrn_rate)

    # Gradients of the cost w.r.t. the trainable variables
    trainable_variables = tf.trainable_variables()
    grads = tf.gradients(self.cost, trainable_variables)

    # Choose the optimizer
    if self.hps.optimizer == 'sgd':
      optimizer = tf.train.GradientDescentOptimizer(self.lrn_rate)
    elif self.hps.optimizer == 'mom':
      optimizer = tf.train.MomentumOptimizer(self.lrn_rate, 0.9)

    # Apply-gradients op
    apply_op = optimizer.apply_gradients(
                        zip(grads, trainable_variables),
                        global_step=self.global_step, 
                        name='train_step')
    
    # Merge with the BN moving-average update ops
    train_ops = [apply_op] + self._extra_train_ops
    # Group everything into a single train op
    self.train_op = tf.group(*train_ops)


  # Convert a scalar stride into the 4-element stride array tf.nn.conv2d expects
  def _stride_arr(self, stride):    
    return [1, stride, stride, 1]

  # Standard residual unit
  def _residual(self, x, in_filter, out_filter, stride, activate_before_residual=False):
    # Pre-activation: apply BN and ReLU before taking the shortcut
    if activate_before_residual:
      with tf.variable_scope('shared_activation'):
        # BN and ReLU first
        x = self._batch_norm('init_bn', x)
        x = self._relu(x, self.hps.relu_leakiness)
        # Take the shortcut (identity) branch
        orig_x = x
    else:
      with tf.variable_scope('residual_only_activation'):
        # Take the shortcut (identity) branch
        orig_x = x
        # BN and ReLU afterwards
        x = self._batch_norm('init_bn', x)
        x = self._relu(x, self.hps.relu_leakiness)

    # First sub-layer
    with tf.variable_scope('sub1'):
      # 3x3 convolution with the input stride, in_filter -> out_filter channels
      x = self._conv('conv1', x, 3, in_filter, out_filter, stride)

    # Second sub-layer
    with tf.variable_scope('sub2'):
      # BN and ReLU
      x = self._batch_norm('bn2', x)
      x = self._relu(x, self.hps.relu_leakiness)
      # 3x3 convolution, stride 1, channels unchanged (out_filter)
      x = self._conv('conv2', x, 3, out_filter, out_filter, [1, 1, 1, 1])
    
    # Merge the residual and shortcut branches
    with tf.variable_scope('sub_add'):
      # When the channel count changes
      if in_filter != out_filter:
        # Average pooling, no padding
        orig_x = tf.nn.avg_pool(orig_x, stride, stride, 'VALID')
        # Zero-pad the channel dimension symmetrically (4th dimension)
        orig_x = tf.pad(orig_x, 
                        [[0, 0], 
                         [0, 0], 
                         [0, 0],
                         [(out_filter-in_filter)//2, (out_filter-in_filter)//2]
                        ])
      # Add the shortcut
      x += orig_x

    tf.logging.debug('image after unit %s', x.get_shape())
    return x
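  # Note: this parameter-free shortcut (average pooling plus zero-padding of the
  # channel dimension) corresponds to "option A" in He et al., arXiv:1512.03385;
  # _bottleneck_residual below uses a learned 1x1 projection ("option B") instead.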

  # Bottleneck residual unit
  def _bottleneck_residual(self, x, in_filter, out_filter, stride,
                           activate_before_residual=False):
    # Pre-activation: apply BN and ReLU before taking the shortcut
    if activate_before_residual:
      with tf.variable_scope('common_bn_relu'):
        # BN and ReLU first
        x = self._batch_norm('init_bn', x)
        x = self._relu(x, self.hps.relu_leakiness)
        # Take the shortcut (identity) branch
        orig_x = x
    else:
      with tf.variable_scope('residual_bn_relu'):
        # Take the shortcut (identity) branch
        orig_x = x
        # BN and ReLU afterwards
        x = self._batch_norm('init_bn', x)
        x = self._relu(x, self.hps.relu_leakiness)

    # First sub-layer
    with tf.variable_scope('sub1'):
      # 1x1 convolution with the input stride, in_filter -> out_filter/4 channels
      x = self._conv('conv1', x, 1, in_filter, out_filter // 4, stride)

    # Second sub-layer
    with tf.variable_scope('sub2'):
      # BN and ReLU
      x = self._batch_norm('bn2', x)
      x = self._relu(x, self.hps.relu_leakiness)
      # 3x3 convolution, stride 1, channels unchanged (out_filter/4)
      x = self._conv('conv2', x, 3, out_filter // 4, out_filter // 4, [1, 1, 1, 1])

    # Third sub-layer
    with tf.variable_scope('sub3'):
      # BN and ReLU
      x = self._batch_norm('bn3', x)
      x = self._relu(x, self.hps.relu_leakiness)
      # 1x1 convolution, stride 1, out_filter/4 -> out_filter channels
      x = self._conv('conv3', x, 1, out_filter // 4, out_filter, [1, 1, 1, 1])

    # Merge the residual and shortcut branches
    with tf.variable_scope('sub_add'):
      # When the channel count changes
      if in_filter != out_filter:
        # 1x1 projection convolution with the input stride, in_filter -> out_filter
        orig_x = self._conv('project', orig_x, 1, in_filter, out_filter, stride)
      
      # Add the shortcut
      x += orig_x

    tf.logging.info('image after unit %s', x.get_shape())
    return x


  # Batch normalization:
  # y = gamma * (x - mean) / sqrt(variance + epsilon) + beta
  def _batch_norm(self, name, x):
    with tf.variable_scope(name):
      # Number of input channels
      params_shape = [x.get_shape()[-1]]
      # offset
      beta = tf.get_variable('beta', 
                             params_shape, 
                             tf.float32,
                             initializer=tf.constant_initializer(0.0, tf.float32))
      # scale
      gamma = tf.get_variable('gamma', 
                              params_shape, 
                              tf.float32,
                              initializer=tf.constant_initializer(1.0, tf.float32))

      if self.mode == 'train':
        # Compute the per-channel mean and variance
        mean, variance = tf.nn.moments(x, [0, 1, 2], name='moments')
        # Create the moving mean/variance used at eval time
        moving_mean = tf.get_variable('moving_mean', 
                                      params_shape, tf.float32,
                                      initializer=tf.constant_initializer(0.0, tf.float32),
                                      trainable=False)
        moving_variance = tf.get_variable('moving_variance', 
                                          params_shape, tf.float32,
                                          initializer=tf.constant_initializer(1.0, tf.float32),
                                          trainable=False)
        # Moving-average update ops for the batch statistics (decay = 0.9):
        # moving_mean = moving_mean * decay + mean * (1 - decay)
        # moving_variance = moving_variance * decay + variance * (1 - decay)
        self._extra_train_ops.append(moving_averages.assign_moving_average(
                                                        moving_mean, mean, 0.9))
        self._extra_train_ops.append(moving_averages.assign_moving_average(
                                                        moving_variance, variance, 0.9))
      else:
        # Use the moving averages accumulated during training
        mean = tf.get_variable('moving_mean', 
                               params_shape, tf.float32,
                               initializer=tf.constant_initializer(0.0, tf.float32),
                               trainable=False)
        variance = tf.get_variable('moving_variance', 
                                   params_shape, tf.float32,
                                   initializer=tf.constant_initializer(1.0, tf.float32),
                                   trainable=False)
        # Add histogram summaries
        tf.summary.histogram(mean.op.name, mean)
        tf.summary.histogram(variance.op.name, variance)

      # BN: y = gamma * (x - mean) / sqrt(variance + epsilon) + beta, epsilon = 0.001
      y = tf.nn.batch_normalization(x, mean, variance, beta, gamma, 0.001)
      y.set_shape(x.get_shape())
      return y


  # Weight decay: L2 regularization loss
  def _decay(self):
    costs = []
    # Iterate over all trainable variables
    for var in tf.trainable_variables():
      # Only include variables whose name contains 'DW'
      if var.op.name.find(r'DW') > 0:
        costs.append(tf.nn.l2_loss(var))
    # Sum the terms and multiply by the decay rate
    return tf.multiply(self.hps.weight_decay_rate, tf.add_n(costs))
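  # (tf.nn.l2_loss(w) returns sum(w**2) / 2, so the result is
  # weight_decay_rate * sum of ||w||^2 / 2 over every variable named 'DW'.)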

  # 2D convolution
  def _conv(self, name, x, filter_size, in_filters, out_filters, strides):
    with tf.variable_scope(name):
      n = filter_size * filter_size * out_filters
      # Get or create the convolution kernel, random-normal initialized
      kernel = tf.get_variable(
              'DW', 
              [filter_size, filter_size, in_filters, out_filters],
              tf.float32, 
              initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0/n)))
      # Compute the convolution
      return tf.nn.conv2d(x, kernel, strides, padding='SAME')
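  # stddev = sqrt(2/n) with n = filter_size^2 * out_filters is the He ("MSRA")
  # initialization for ReLU networks (arXiv:1502.01852), fan-out variant.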

  # Leaky ReLU activation; leakiness = 0 gives the standard ReLU
  def _relu(self, x, leakiness=0.0):
    return tf.where(tf.less(x, 0.0), leakiness * x, x, name='leaky_relu')
  
  # Fully connected layer (the last layer of the network)
  def _fully_connected(self, x, out_dim):
    # Flatten the input into a 2-D tensor of shape [batch_size, -1]
    x = tf.reshape(x, [self.hps.batch_size, -1])
    # Weight w, uniform unit-scaling initialization: [-sqrt(3/dim), sqrt(3/dim)] * factor
    w = tf.get_variable('DW', [x.get_shape()[1], out_dim],
                        initializer=tf.uniform_unit_scaling_initializer(factor=1.0))
    # Bias b, zero-initialized
    b = tf.get_variable('biases', [out_dim], initializer=tf.constant_initializer())
    # Compute x*w + b
    return tf.nn.xw_plus_b(x, w, b)

  # Global average pooling
  def _global_avg_pool(self, x):
    assert x.get_shape().ndims == 4
    # Average over the spatial dimensions (height and width): WxH shrinks to 1x1
    return tf.reduce_mean(x, [1, 2])
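To tie the three files together, here is a minimal usage sketch mirroring what main.py does (the placeholder inputs stand in for the cifar_input queues and are assumptions for illustration only):

import tensorflow as tf
import resnet_model

hps = resnet_model.HParams(
    batch_size=128, num_classes=10, min_lrn_rate=0.0001, lrn_rate=0.1,
    num_residual_units=5, use_bottleneck=False, weight_decay_rate=0.0002,
    relu_leakiness=0.1, optimizer='mom')

# Placeholders stand in for the cifar_input pipeline (illustration only).
images = tf.placeholder(tf.float32, [hps.batch_size, 32, 32, 3])
labels = tf.placeholder(tf.float32, [hps.batch_size, hps.num_classes])

model = resnet_model.ResNet(hps, images, labels, 'train')
model.build_graph()  # defines model.cost, model.predictions and model.train_op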
