TF Learning: DeepLabv3+ Code Reading 6 (train_utils)

Reading the DeepLabv3+ code: train_utils.py

1. get_model_learning_rate()

def get_model_learning_rate(learning_policy,  # Learning rate policy for training.
                            base_learning_rate,  # The base learning rate for model training.
                            learning_rate_decay_step,  # Decay the base learning rate at a fixed step.
                            learning_rate_decay_factor,  # The rate to decay the base learning rate.
                            training_number_of_steps,  # Number of steps for training.
                            learning_power,  # Power used for 'poly' learning policy.
                            slow_start_step,  # Train with a small learning rate for the first few steps.
                            slow_start_learning_rate,  # The learning rate employed during slow start.
                            slow_start_burnin_type='none'):  # The burnin type for the slow start stage:
                                # `none` means no burnin; `linear` means the learning rate
                                # increases linearly from slow_start_learning_rate and reaches
                                # base_learning_rate after slow_start_step steps.
  """Gets model's learning rate.

  Computes the model's learning rate for different learning policies.
  Right now, only "step" and "poly" are supported.
  (1) The learning policy for "step" is computed as follows:
    current_learning_rate = base_learning_rate *
      learning_rate_decay_factor ^ (global_step / learning_rate_decay_step)
  See tf.train.exponential_decay for details.
  (2) The learning policy for "poly" is computed as follows:
    current_learning_rate = base_learning_rate *
      (1 - global_step / training_number_of_steps) ^ learning_power

  """
  global_step = tf.train.get_or_create_global_step()
  adjusted_global_step = global_step

  if slow_start_burnin_type != 'none':
    adjusted_global_step -= slow_start_step

  if learning_policy == 'step':
    learning_rate = tf.train.exponential_decay(
        base_learning_rate,
        adjusted_global_step,
        learning_rate_decay_step,
        learning_rate_decay_factor,
        staircase=True)
  elif learning_policy == 'poly':
    learning_rate = tf.train.polynomial_decay(
        base_learning_rate,
        adjusted_global_step,
        training_number_of_steps,
        end_learning_rate=0,
        power=learning_power)
  else:
    raise ValueError('Unknown learning policy.')

  adjusted_slow_start_learning_rate = slow_start_learning_rate
  if slow_start_burnin_type == 'linear':
    # Do linear burnin. Increase linearly from slow_start_learning_rate and
    # reach base_learning_rate after (global_step >= slow_start_step).
    adjusted_slow_start_learning_rate = (
        slow_start_learning_rate +
        (base_learning_rate - slow_start_learning_rate) *
        tf.to_float(global_step) / slow_start_step)
  elif slow_start_burnin_type != 'none':
    raise ValueError('Unknown burnin type.')

  # Employ small learning rate at the first few steps for warm start.
  return tf.where(global_step < slow_start_step,
                  adjusted_slow_start_learning_rate, learning_rate)
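
For context, here is a minimal sketch of how the returned learning-rate tensor is typically consumed during training (the hyper-parameter values below are illustrative only; in DeepLab's train.py they come from FLAGS):

import tensorflow as tf  # TF 1.x

# Build a 'poly' schedule and feed it to a momentum optimizer.
learning_rate = get_model_learning_rate(
    learning_policy='poly',
    base_learning_rate=0.007,
    learning_rate_decay_step=2000,
    learning_rate_decay_factor=0.1,
    training_number_of_steps=30000,
    learning_power=0.9,
    slow_start_step=0,
    slow_start_learning_rate=1e-4)
optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)
tf.summary.scalar('learning_rate', learning_rate)  # log the schedule to TensorBoard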

2. add_softmax_cross_entropy_loss_for_each_scale

Computes the softmax cross-entropy loss for the logits at each scale (a usage sketch follows the function below).
Arguments:
	scales_to_logits: a map from logits names for different scales to logits, each of shape [batch, logits_height, logits_width, num_classes].
	labels: ground-truth labels, shape [batch, image_height, image_width, 1].
	num_classes: number of classes.
	ignore_label: label value to ignore when computing the loss.
	loss_weight: weight of the loss (default 1.0).
	upsample_logits: whether to upsample the logits to the label size (otherwise the labels are downsampled to the logits size).
	hard_example_mining_step: training step at which the mining percentage reaches top_k_percent_pixels (default 0, i.e. mine the top k percent of pixels from the start).
	top_k_percent_pixels: a float in [0.0, 1.0] (default 1.0); when < 1.0, only the top k percent of pixel losses are kept (hard-example mining).
	scope: the scope for the loss.
def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
                                                  labels,
                                                  num_classes,
                                                  ignore_label,
                                                  loss_weight=1.0,
                                                  upsample_logits=True,
                                                  hard_example_mining_step=0,
                                                  top_k_percent_pixels=1.0,
                                                  scope=None):

  if labels is None:
    raise ValueError('No label for softmax cross entropy loss.')
  # outputs_to_scales_to_logits = {k: {} for k in model_options.outputs_to_num_classes}
  # model_options.outputs_to_num_classes = {'semantic':21}
  # outputs_to_scales_to_logits = {'semantic': {'merged_logits': {}}}
  # scales_to_logits = outputs_to_scales_to_logits['semantic'] = {'merged_logits': {}}
  for scale, logits in six.iteritems(scales_to_logits):
    loss_scope = None
    if scope:
      loss_scope = '%s_%s' % (scope, scale)# 'semantic_merged_logits'

    if upsample_logits:
      # Label is not downsampled, and instead we upsample logits.
      logits = tf.image.resize_bilinear(  # upsample logits with bilinear interpolation
          logits,
          preprocess_utils.resolve_shape(labels, 4)[1:3],
          align_corners=True)
      scaled_labels = labels
    else:
      # Label is downsampled to the same size as logits.
      scaled_labels = tf.image.resize_nearest_neighbor(  # downsample labels with nearest-neighbor interpolation
          labels,
          preprocess_utils.resolve_shape(logits, 4)[1:3],
          align_corners=True)
    # Flatten the labels into a 1-D vector of shape [batch * height * width].
    scaled_labels = tf.reshape(scaled_labels, shape=[-1])
    # tf.not_equal(x, y) returns the element-wise truth value of x != y, so this
    # builds a mask of pixels whose label is not ignore_label, scaled by loss_weight.
    not_ignore_mask = tf.to_float(tf.not_equal(scaled_labels,
                                               ignore_label)) * loss_weight
    one_hot_labels = tf.one_hot(
        scaled_labels, num_classes, on_value=1.0, off_value=0.0)  # convert to one-hot labels

    if top_k_percent_pixels == 1.0:
      # Compute the loss for all pixels.
      tf.losses.softmax_cross_entropy(
          one_hot_labels,
          tf.reshape(logits, shape=[-1, num_classes]),
          weights=not_ignore_mask,
          scope=loss_scope)# loss_scope
    else:
      logits = tf.reshape(logits, shape=[-1, num_classes])
      weights = not_ignore_mask
      with tf.name_scope(loss_scope, 'softmax_hard_example_mining',
                         [logits, one_hot_labels, weights]):
        one_hot_labels = tf.stop_gradient(
            one_hot_labels, name='labels_stop_gradient')
        pixel_losses = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=one_hot_labels,
            logits=logits,
            name='pixel_losses')
        weighted_pixel_losses = tf.multiply(pixel_losses, weights)
        num_pixels = tf.to_float(tf.shape(logits)[0])
        # Compute the top_k_percent pixels based on current training step.
        if hard_example_mining_step == 0:
          # Directly focus on the top_k pixels.
          top_k_pixels = tf.to_int32(top_k_percent_pixels * num_pixels)
        else:
          # Gradually reduce the mining percent to top_k_percent_pixels.
          global_step = tf.to_float(tf.train.get_or_create_global_step())
          ratio = tf.minimum(1.0, global_step / hard_example_mining_step)
          top_k_pixels = tf.to_int32(
              (ratio * top_k_percent_pixels + (1.0 - ratio)) * num_pixels)
        top_k_losses, _ = tf.nn.top_k(weighted_pixel_losses,
                                      k=top_k_pixels,
                                      sorted=True,
                                      name='top_k_percent_pixels')
        total_loss = tf.reduce_sum(top_k_losses)
        num_present = tf.reduce_sum(
            tf.to_float(tf.not_equal(top_k_losses, 0.0)))
        loss = _div_maybe_zero(total_loss, num_present)
        tf.losses.add_loss(loss)
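
Note that the function registers its losses in the tf.losses collection rather than returning them, so the caller gathers them afterwards. A minimal usage sketch, assuming `outputs_to_scales_to_logits` and `labels` come from the model and input pipeline as described in the comments above (num_classes=21 and ignore_label=255 are the usual PASCAL VOC values, shown purely for illustration):

import tensorflow as tf  # TF 1.x

# Add the cross-entropy loss for the 'semantic' output, then collect
# everything registered via tf.losses into a single scalar.
add_softmax_cross_entropy_loss_for_each_scale(
    outputs_to_scales_to_logits['semantic'],  # {'merged_logits': logits tensor}
    labels,                                   # [batch, H, W, 1] integer labels
    num_classes=21,
    ignore_label=255,
    loss_weight=1.0,
    upsample_logits=True,
    scope='semantic')
total_loss = tf.losses.get_total_loss(add_regularization_losses=True)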

3. get_model_gradient_multipliers

Gradient multipliers adjust the learning rate for the model's variables. For the segmentation task, the model is usually fine-tuned from a model trained on an image-classification task, and a larger learning rate (e.g., 10x) is typically used for the last layers (a sketch of how the returned map is applied follows the function below).
Arguments:
	last_layers: scopes of the last layers.
	last_layer_gradient_multiplier: gradient multiplier for the last layers.
Returns:
	a map of gradient multipliers, {variable: multiplier value}.
def get_model_gradient_multipliers(last_layers, last_layer_gradient_multiplier):

  gradient_multipliers = {}

  for var in tf.model_variables():
    # Double the learning rate for biases.
    if 'biases' in var.op.name:
      gradient_multipliers[var.op.name] = 2.

    # Use larger learning rate for last layer variables.
    for layer in last_layers:
      if layer in var.op.name and 'biases' in var.op.name:
        gradient_multipliers[var.op.name] = 2 * last_layer_gradient_multiplier
        break
      elif layer in var.op.name:
        gradient_multipliers[var.op.name] = last_layer_gradient_multiplier
        break

  return gradient_multipliers
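
In the DeepLab training script this map is applied to the computed gradients; the sketch below applies it by hand to make the effect explicit (a minimal sketch with dense gradients assumed; `optimizer`, `total_loss`, and `last_layers` stand for objects defined elsewhere):

import tensorflow as tf  # TF 1.x

grad_mult = get_model_gradient_multipliers(
    last_layers, last_layer_gradient_multiplier=10.0)
grads_and_vars = optimizer.compute_gradients(total_loss)
scaled_grads_and_vars = []
for grad, var in grads_and_vars:
  multiplier = grad_mult.get(var.op.name, 1.0)  # 1.0 for variables not in the map
  if grad is not None and multiplier != 1.0:
    grad = grad * multiplier                    # scale the gradient, not the variable
  scaled_grads_and_vars.append((grad, var))
train_op = optimizer.apply_gradients(
    scaled_grads_and_vars, global_step=tf.train.get_or_create_global_step())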

4. get_model_init_fn

Gets the function that initializes the model from the given checkpoint.
Arguments:
	train_logdir: directory where training logs and checkpoints are written.
	tf_initial_checkpoint: the checkpoint used for initialization.
	initialize_last_layer: whether to initialize the last layer.
	last_layers: the last layers of the model.
	ignore_missing_vars: ignore variables that are missing from the checkpoint.
Returns:
	a restore function that initializes the model from the checkpoint, or None; a usage sketch follows the function below.
def get_model_init_fn(train_logdir,
                      tf_initial_checkpoint,
                      initialize_last_layer,
                      last_layers,
                      ignore_missing_vars=False):

  if tf_initial_checkpoint is None:
    tf.logging.info('Not initializing the model from a checkpoint.')
    return None

  if tf.train.latest_checkpoint(train_logdir):  # a checkpoint already exists in train_logdir
    tf.logging.info('Ignoring initialization; other checkpoint exists')
    return None

  tf.logging.info('Initializing model from path: %s', tf_initial_checkpoint)

  # Variables that will not be restored.
  exclude_list = ['global_step']
  if not initialize_last_layer:
    exclude_list.extend(last_layers)

  variables_to_restore = tf.contrib.framework.get_variables_to_restore(
      exclude=exclude_list)

  if variables_to_restore:
    init_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint(
        tf_initial_checkpoint,
        variables_to_restore,
        ignore_missing_vars=ignore_missing_vars)
    global_step = tf.train.get_or_create_global_step()

    def restore_fn(unused_scaffold, sess):
      sess.run(init_op, init_feed_dict)
      sess.run([global_step])

    return restore_fn

  return None
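
The returned restore_fn takes (scaffold, session), which matches the init_fn signature expected by tf.train.Scaffold. A minimal sketch of how it can be plugged into a MonitoredTrainingSession (the paths are illustrative, and `last_layers` and `train_op` are assumed to exist):

import tensorflow as tf  # TF 1.x

init_fn = get_model_init_fn(
    train_logdir='/tmp/deeplab_train',             # illustrative path
    tf_initial_checkpoint='/tmp/init/model.ckpt',  # illustrative path
    initialize_last_layer=False,
    last_layers=last_layers,
    ignore_missing_vars=True)
scaffold = tf.train.Scaffold(init_fn=init_fn)  # init_fn may be None; Scaffold accepts that
with tf.train.MonitoredTrainingSession(
    checkpoint_dir='/tmp/deeplab_train', scaffold=scaffold) as sess:
  while not sess.should_stop():
    sess.run(train_op)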