TF學習之DeepLabv3+代碼閱讀1(train)

DeepLabv3+代碼閱讀之train.py

一、main()

def main(unused_argv):
  """Builds the DeepLab training graph and runs the training loop.

  Args:
    unused_argv: Extra command-line arguments; ignored. `main` must still
      accept one positional argument (its name is arbitrary), otherwise the
      launcher fails with "TypeError: main() takes no arguments (1 given)".

  Raises:
    ValueError: If FLAGS.train_batch_size is not divisible by
      FLAGS.num_clones.
  """
  tf.logging.set_verbosity(tf.logging.INFO)  # Log at INFO level.
  tf.gfile.MakeDirs(FLAGS.train_logdir)
  tf.logging.info('Training on %s set', FLAGS.train_split)

  # Build everything in a fresh tf.Graph, installed as the default graph for
  # the duration of the `with` block.
  graph = tf.Graph()
  with graph.as_default():
    with tf.device(tf.train.replica_device_setter(ps_tasks=FLAGS.num_ps_tasks)):
      # Explicit check instead of `assert`: asserts are stripped under
      # `python -O`, after which the `//` below would silently shrink the
      # effective batch size.
      if FLAGS.train_batch_size % FLAGS.num_clones != 0:
        raise ValueError(
            'Training batch size not divisible by number of clones (GPUs).')
      clone_batch_size = FLAGS.train_batch_size // FLAGS.num_clones

      dataset = data_generator.Dataset(
          dataset_name=FLAGS.dataset,
          split_name=FLAGS.train_split,
          dataset_dir=FLAGS.dataset_dir,
          batch_size=clone_batch_size,
          # Crop-size flag entries are converted with int(), e.g.
          # ['513', '513'] -> [513, 513].
          crop_size=[int(sz) for sz in FLAGS.train_crop_size],
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          min_scale_factor=FLAGS.min_scale_factor,
          max_scale_factor=FLAGS.max_scale_factor,
          scale_factor_step_size=FLAGS.scale_factor_step_size,
          model_variant=FLAGS.model_variant,
          num_readers=2,
          is_training=True,
          should_shuffle=True,
          should_repeat=True)

      train_tensor, summary_op = _train_deeplab_model(
          dataset.get_one_shot_iterator(), dataset.num_of_classes,
          dataset.ignore_label)

      # Soft placement allows placing on CPU ops without GPU implementation.
      session_config = tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=False)

      # Optionally initialize weights from FLAGS.tf_initial_checkpoint.
      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      init_fn = None
      if FLAGS.tf_initial_checkpoint:
        init_fn = train_utils.get_model_init_fn(
            FLAGS.train_logdir,
            FLAGS.tf_initial_checkpoint,
            FLAGS.initialize_last_layer,
            last_layers,
            ignore_missing_vars=True)

      scaffold = tf.train.Scaffold(
          init_fn=init_fn,
          summary_op=summary_op,
      )

      # Stop once the global step reaches FLAGS.training_number_of_steps.
      stop_hook = tf.train.StopAtStepHook(
          last_step=FLAGS.training_number_of_steps)

      profile_dir = FLAGS.profile_logdir
      if profile_dir is not None:
        tf.gfile.MakeDirs(profile_dir)

      # Profiling is enabled only when a profile directory is given.
      with tf.contrib.tfprof.ProfileContext(
          enabled=profile_dir is not None, profile_dir=profile_dir):
        with tf.train.MonitoredTrainingSession(
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            config=session_config,
            scaffold=scaffold,
            checkpoint_dir=FLAGS.train_logdir,
            summary_dir=FLAGS.train_logdir,
            log_step_count_steps=FLAGS.log_steps,
            # The flag (`save_summaries_secs`) is a period in seconds, so pass
            # it as seconds; passing it as `save_summaries_steps` treated the
            # seconds value as a step count.
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_checkpoint_secs=FLAGS.save_interval_secs,
            hooks=[stop_hook]) as sess:
          while not sess.should_stop():
            sess.run([train_tensor])

二、_train_deeplab_model()

def _train_deeplab_model(iterator, num_of_classes, ignore_label):
  """Builds the DeepLab training graph: per-tower losses, averaged gradients
  and the update op.

  This only constructs graph ops; the caller runs `train_tensor` in a
  session. One tower (clone) is built on each of FLAGS.num_clones GPUs.
  Their gradients are averaged under '/cpu:0' and applied once.

  Args:
    iterator: An iterator of type tf.data.Iterator for images and labels.
    num_of_classes: Number of classes for the dataset.
    ignore_label: Ignore label for the dataset.

  Returns:
    train_tensor: A tensor that, when run, first runs the gradient update and
      the ops in the UPDATE_OPS collection, then yields the total loss.
    summary_op: An operation to log the summaries, excluding summaries
      created under the 'clone_*' name scopes of the non-first towers.
  """
  # Global step counter. It is passed to `optimizer.apply_gradients` below,
  # which increments it each time the update op runs (once per train step).
  global_step = tf.train.get_or_create_global_step()
  # Trailing comments are example flag values noted by the original author;
  # the values actually used come from FLAGS.
  learning_rate = train_utils.get_model_learning_rate(
      FLAGS.learning_policy, FLAGS.base_learning_rate,# 'poly', .0001
      FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,# 2000, 0.1
      FLAGS.training_number_of_steps, FLAGS.learning_power,# 30000, 0.9
      FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)# 0, 1e-4
  # Plot the learning rate in TensorBoard.
  tf.summary.scalar('learning_rate', learning_rate)

  optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)

  tower_losses = []
  tower_grads = []
  # Build one tower (clone) per GPU; FLAGS.num_clones is the number of GPUs.
  for i in range(FLAGS.num_clones):
    # Pin tower i to GPU i (with a single clone this is just '/gpu:0').
    with tf.device('/gpu:%d' % i):
      # First tower has default name scope.
      # Tower 0 uses '' and towers 1, 2, ... use 'clone_1', 'clone_2', ...;
      # summary_op at the end filters summaries on this 'clone_' prefix.
      name_scope = ('clone_%d' % i) if i else ''
      with tf.name_scope(name_scope) as scope:
        # Every tower after the first reuses the first tower's variables.
        loss = _tower_loss(
            iterator=iterator,
            num_of_classes=num_of_classes,
            ignore_label=ignore_label,
            scope=scope,
            reuse_variable=(i != 0))
        tower_losses.append(loss)

  # Quantization-aware training (disabled when the delay step is negative).
  # The graph rewrite runs after all towers are built and before gradients
  # are computed; it is only supported with a single clone.
  if FLAGS.quantize_delay_step >= 0:
    if FLAGS.num_clones > 1:
      raise ValueError('Quantization doesn\'t support multi-clone yet.')
    tf.contrib.quantize.create_training_graph(
        quant_delay=FLAGS.quantize_delay_step)

  # Compute each tower's gradients on that tower's device and name scope.
  for i in range(FLAGS.num_clones):
    with tf.device('/gpu:%d' % i):
      name_scope = ('clone_%d' % i) if i else ''
      with tf.name_scope(name_scope) as scope:
        grads = optimizer.compute_gradients(tower_losses[i])
        tower_grads.append(grads)

  with tf.device('/cpu:0'):
    # Average the per-tower gradients into one (gradient, variable) list.
    grads_and_vars = _average_gradients(tower_grads)

    # Modify the gradients for biases and last layer variables.
    last_layers = model.get_extra_layer_scopes(
        FLAGS.last_layers_contain_logits_only)
    grad_mult = train_utils.get_model_gradient_multipliers(
        last_layers, FLAGS.last_layer_gradient_multiplier)
    if grad_mult:
      grads_and_vars = tf.contrib.training.multiply_gradients(
          grads_and_vars, grad_mult)

    # Create gradient update op.
    grad_updates = optimizer.apply_gradients(
        grads_and_vars, global_step=global_step)

    # Gather update_ops. These contain, for example,
    # the updates for the batch_norm variables created by model_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    update_ops.append(grad_updates)
    update_op = tf.group(*update_ops)

    # Total of all losses registered in the graph plus regularization losses.
    total_loss = tf.losses.get_total_loss(add_regularization_losses=True)

    # Print total loss to the terminal.
    # This implementation is mirrored from tf.slim.summaries.
    should_log = math_ops.equal(math_ops.mod(global_step, FLAGS.log_steps), 0)
    total_loss = tf.cond(
        should_log,
        lambda: tf.Print(total_loss, [total_loss], 'Total loss is :'),
        lambda: total_loss)

    tf.summary.scalar('total_loss', total_loss)
    # Running train_tensor forces update_op (all updates) to run first.
    with tf.control_dependencies([update_op]):
      train_tensor = tf.identity(total_loss, name='train_op')

    # Excludes summaries from towers other than the first one.
    summary_op = tf.summary.merge_all(scope='(?!clone_)')

  return train_tensor, summary_op

三、_tower_loss

在一個 tower 上構建模型並計算該 tower 的總 loss(本 tower 所有 loss 之和,加上正則化 loss)
輸入參數:
	iterator: tf.data.Iterator類的迭代器,輸入images和labels.
	num_of_classes: 類別數
	ignore_label: 忽略標籤編號
	scope: 標識該 deeplab tower 的唯一名稱前綴,用於只取出本 tower 的 loss 和正則化 loss.
	reuse_variable: 是否重用已創建的變量(第一個 tower 爲 False,其餘 tower 爲 True).
返回:
	一個batch的數據上的總loss
def _tower_loss(iterator, num_of_classes, ignore_label, scope, reuse_variable):
  """Builds one tower of the model and returns that tower's total loss.

  Args:
    iterator: tf.data.Iterator yielding image/label samples.
    num_of_classes: Number of classes for the dataset.
    ignore_label: Label value excluded from the loss.
    scope: Name-scope prefix of this tower, used to select only its losses.
    reuse_variable: When true, reuse variables already created by an
      earlier tower instead of creating new ones.

  Returns:
    Scalar tensor: sum of this tower's losses plus its regularization loss.
  """
  # reuse=None creates fresh variables (first tower); reuse=True shares them.
  reuse = True if reuse_variable else None
  with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
    _build_deeplab(iterator, {common.OUTPUT_TYPE: num_of_classes}, ignore_label)

  tower_losses = tf.losses.get_losses(scope=scope)
  for tower_loss in tower_losses:
    tf.summary.scalar('Losses/%s' % tower_loss.op.name, tower_loss)

  reg_loss = tf.losses.get_regularization_loss(scope=scope)
  tf.summary.scalar('Losses/%s' % reg_loss.op.name, reg_loss)

  data_loss = tf.add_n(tower_losses)
  return tf.add_n([data_loss, reg_loss])

四、_build_deeplab

創建DeepLab
輸入參數: 
	iterator: images和labels的迭代器
	outputs_to_num_classes: 輸出類型到類別數的映射(dict),如:outputs_to_num_classes['semantic'] = 21
	ignore_label: 忽略的標籤編號
def _build_deeplab(iterator, outputs_to_num_classes, ignore_label):
  """Builds the DeepLab forward graph for one batch and registers its losses.

  Args:
    iterator: tf.data iterator whose get_next() returns a dict of sample
      tensors that includes the common.IMAGE and common.LABEL entries.
    outputs_to_num_classes: Dict mapping each output type to its number of
      classes, e.g. {common.OUTPUT_TYPE: 21}.
    ignore_label: Label value excluded from the loss.
  """
  samples = iterator.get_next()

  # Give the image and label tensors fixed node names for the summaries.
  for key in (common.IMAGE, common.LABEL):
    samples[key] = tf.identity(samples[key], name=key)
  images = samples[common.IMAGE]
  labels = samples[common.LABEL]

  crop_size = [int(sz) for sz in FLAGS.train_crop_size]
  model_options = common.ModelOptions(
      outputs_to_num_classes=outputs_to_num_classes,
      crop_size=crop_size,
      atrous_rates=FLAGS.atrous_rates,
      output_stride=FLAGS.output_stride)

  nas_hparams = {
      'drop_path_keep_prob': FLAGS.drop_path_keep_prob,
      'total_training_steps': FLAGS.training_number_of_steps,
  }
  outputs_to_scales_to_logits = model.multi_scale_logits(
      images,
      model_options=model_options,
      image_pyramid=FLAGS.image_pyramid,
      weight_decay=FLAGS.weight_decay,
      is_training=True,
      fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
      nas_training_hyper_parameters=nas_hparams)

  # Rename the merged logits node (to common.OUTPUT_TYPE) for the summaries.
  scales_to_logits = outputs_to_scales_to_logits[common.OUTPUT_TYPE]
  scales_to_logits[model.MERGED_LOGITS_SCOPE] = tf.identity(
      scales_to_logits[model.MERGED_LOGITS_SCOPE], name=common.OUTPUT_TYPE)

  # six.iteritems iterates (key, value) pairs on both Python 2 and 3.
  for output, num_classes in six.iteritems(outputs_to_num_classes):
    train_utils.add_softmax_cross_entropy_loss_for_each_scale(
        outputs_to_scales_to_logits[output],
        labels,
        num_classes,
        ignore_label,
        loss_weight=1.0,
        upsample_logits=FLAGS.upsample_logits,
        hard_example_mining_step=FLAGS.hard_example_mining_step,
        top_k_percent_pixels=FLAGS.top_k_percent_pixels,
        scope=output)

    # Summaries are written once per output entry, using the merged logits.
    _log_summaries(images, labels, num_classes,
                   scales_to_logits[model.MERGED_LOGITS_SCOPE])

五、_log_summaries

爲模型記錄 summary:所有模型變量的直方圖;若開啓 FLAGS.save_summaries_images,還會記錄輸入圖片、標籤和預測結果圖像
參數:
	input_image: 輸入圖片,shape: [batch_size, height, width, channel].
	label: 標籤,shape: [batch_size, height, width, 1](要傳給 tf.summary.image,必須是 4 維).
	num_of_classes: 類別數
	output: 模型輸出的 logits,shape: [batch_size, height, width, num_of_classes](代碼在第 3 維上做 argmax 得到預測類別)
def _log_summaries(input_image, label, num_of_classes, output):
  """Adds TensorBoard summaries for the model.

  Always adds a histogram per model variable. When
  FLAGS.save_summaries_images is set, it also adds image summaries of the
  input, the label and the argmax prediction.

  Args:
    input_image: Input image batch; tf.summary.image needs it 4-D
      ([batch, height, width, channels]).
    label: Label batch; also fed to tf.summary.image, so it must be 4-D
      (presumably [batch, height, width, 1] — confirm against the data
      pipeline).
    num_of_classes: Number of classes, used to stretch class ids over 0..255.
    output: Logits; argmax is taken over axis 3, so it is rank-4 with the
      class axis last.
  """
  for var in tf.model_variables():
    tf.summary.histogram(var.op.name, var)

  if not FLAGS.save_summaries_images:
    return

  tf.summary.image('samples/%s' % common.IMAGE, input_image)

  # Spread class ids over [0, 255] so different classes are visibly distinct.
  scale = max(1, 255 // num_of_classes)
  label_image = tf.cast(label * scale, tf.uint8)
  tf.summary.image('samples/%s' % common.LABEL, label_image)

  predicted_ids = tf.argmax(output, 3)
  prediction_image = tf.cast(tf.expand_dims(predicted_ids, -1) * scale,
                             tf.uint8)
  tf.summary.image('samples/%s' % common.OUTPUT_TYPE, prediction_image)

六、_average_gradients

計算平均梯度
參數:
	tower_grads: List of lists of (gradient, variable) tuples.
返回:
	List of pairs of (gradient, variable) where the gradient has been averaged across all towers(代碼用的是 tf.reduce_mean,即取平均而非求和).
def _average_gradients(tower_grads):

  average_grads = []
  for grad_and_vars in zip(*tower_grads):
    # Note that each grad_and_vars looks like the following:
    #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
    grads, variables = zip(*grad_and_vars)
    grad = tf.reduce_mean(tf.stack(grads, axis=0), axis=0)

    # All vars are of the same value, using the first tower here.
    average_grads.append((grad, variables[0]))

  return average_grads
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章