# Module-level imports assumed by the snippets below (as in deeplab/train.py):
import six
import tensorflow as tf
from tensorflow.python.ops import math_ops

from deeplab import common
from deeplab import model
from deeplab.utils import train_utils

FLAGS = tf.app.flags.FLAGS


def _train_deeplab_model(iterator, num_of_classes, ignore_label):
  """Trains the deeplab model.

  Args:
    iterator: An iterator of type tf.data.Iterator for images and labels.
    num_of_classes: Number of classes for the dataset.
    ignore_label: Ignore label for the dataset.

  Returns:
    train_tensor: A tensor to update the model variables.
    summary_op: An operation to log the summaries.
  """
  # Global step: passed to optimizer.apply_gradients() below, so the session
  # increments it by 1 after each batch; it counts the batches run so far.
  global_step = tf.train.get_or_create_global_step()

  learning_rate = train_utils.get_model_learning_rate(
      FLAGS.learning_policy, FLAGS.base_learning_rate,        # 'poly', 0.0001
      FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,  # 2000, 0.1
      FLAGS.training_number_of_steps, FLAGS.learning_power,   # 30000, 0.9
      FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)  # 0, 1e-4
  tf.summary.scalar('learning_rate', learning_rate)  # Plotted in TensorBoard.

  optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)  # The optimizer.
  tower_losses = []
  tower_grads = []
  for i in range(FLAGS.num_clones):  # num_clones = number of GPUs to use.
    # Pin this clone to GPU i. With num_clones=1 the program runs on a single
    # GPU: only i=0 occurs, so tf.device('/gpu:%d' % i) is tf.device('/gpu:0').
    with tf.device('/gpu:%d' % i):
      # First tower has default name scope: '' for i=0, and
      # 'clone_1', 'clone_2', ... for i=1, 2, 3, ...
      name_scope = ('clone_%d' % i) if i else ''
      with tf.name_scope(name_scope) as scope:
        loss = _tower_loss(
            iterator=iterator,
            num_of_classes=num_of_classes,
            ignore_label=ignore_label,
            scope=scope,
            reuse_variable=(i != 0))
        tower_losses.append(loss)

  if FLAGS.quantize_delay_step >= 0:
    if FLAGS.num_clones > 1:
      raise ValueError('Quantization doesn\'t support multi-clone yet.')
    tf.contrib.quantize.create_training_graph(
        quant_delay=FLAGS.quantize_delay_step)

  for i in range(FLAGS.num_clones):
    with tf.device('/gpu:%d' % i):
      name_scope = ('clone_%d' % i) if i else ''
      with tf.name_scope(name_scope) as scope:
        grads = optimizer.compute_gradients(tower_losses[i])  # Compute gradients.
        tower_grads.append(grads)

  with tf.device('/cpu:0'):
    grads_and_vars = _average_gradients(tower_grads)

    # Modify the gradients for biases and last layer variables.
    last_layers = model.get_extra_layer_scopes(
        FLAGS.last_layers_contain_logits_only)
    grad_mult = train_utils.get_model_gradient_multipliers(
        last_layers, FLAGS.last_layer_gradient_multiplier)
    if grad_mult:
      grads_and_vars = tf.contrib.training.multiply_gradients(
          grads_and_vars, grad_mult)

    # Create gradient update op.
    grad_updates = optimizer.apply_gradients(
        grads_and_vars, global_step=global_step)

    # Gather update_ops. These contain, for example,
    # the updates for the batch_norm variables created by model_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    update_ops.append(grad_updates)
    update_op = tf.group(*update_ops)

    total_loss = tf.losses.get_total_loss(add_regularization_losses=True)

    # Print total loss to the terminal.
    # This implementation is mirrored from tf.slim.summaries.
    should_log = math_ops.equal(math_ops.mod(global_step, FLAGS.log_steps), 0)
    total_loss = tf.cond(
        should_log,
        lambda: tf.Print(total_loss, [total_loss], 'Total loss is :'),
        lambda: total_loss)

    tf.summary.scalar('total_loss', total_loss)
    with tf.control_dependencies([update_op]):
      train_tensor = tf.identity(total_loss, name='train_op')

    # Excludes summaries from towers other than the first one.
    summary_op = tf.summary.merge_all(scope='(?!clone_)')

  return train_tensor, summary_op
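
The caller only needs the two returned tensors: each run of train_tensor applies one gradient update (through the control dependency on update_op) and returns the current total loss. A minimal sketch of how they might be driven, assuming a tf.train.MonitoredTrainingSession and PASCAL VOC settings (21 classes, ignore label 255); the session arguments here are illustrative, not the full set used by train.py:

# Hypothetical training loop; num_of_classes/ignore_label assume PASCAL VOC.
train_tensor, summary_op = _train_deeplab_model(
    iterator, num_of_classes=21, ignore_label=255)

with tf.train.MonitoredTrainingSession(
    checkpoint_dir=FLAGS.train_logdir,
    save_summaries_steps=None) as sess:
  while not sess.should_stop():
    sess.run([train_tensor])  # One batch: update variables, fetch the loss.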
3. _tower_loss

Computes the total loss on one tower running the deeplab model.

Args:
  iterator: A tf.data.Iterator yielding images and labels.
  num_of_classes: Number of classes for the dataset.
  ignore_label: Ignore label for the dataset.
  scope: Unique prefix string identifying the deeplab tower.
  reuse_variable: If the variable should be reused.

Returns:
  The total loss for a batch of data.
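
The body is short: under the shared variable scope it builds the network with _build_deeplab, then gathers the cross-entropy losses and the regularization loss recorded in this tower's scope and sums them. A sketch of that logic, reconstructed from the calls visible above (reuse_variable turns on variable reuse for every tower after the first):

def _tower_loss(iterator, num_of_classes, ignore_label, scope, reuse_variable):
  """Calculates the total loss on a single tower running the deeplab model."""
  # Reuse the model variables on all towers except the first one.
  with tf.variable_scope(
      tf.get_variable_scope(), reuse=True if reuse_variable else None):
    _build_deeplab(iterator, {common.OUTPUT_TYPE: num_of_classes}, ignore_label)

  losses = tf.losses.get_losses(scope=scope)
  for loss in losses:
    tf.summary.scalar('Losses/%s' % loss.op.name, loss)

  regularization_loss = tf.losses.get_regularization_loss(scope=scope)
  tf.summary.scalar('Losses/%s' % regularization_loss.op.name,
                    regularization_loss)

  # Total loss on this tower = cross-entropy losses + regularization loss.
  total_loss = tf.add_n([tf.add_n(losses), regularization_loss])
  return total_loss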
4. _build_deeplab

Builds one clone of DeepLab: pulls a batch from the iterator, runs the network to get multi-scale logits, and adds the per-scale softmax cross-entropy losses.

def _build_deeplab(iterator, outputs_to_num_classes, ignore_label):
  samples = iterator.get_next()

  # Add name to input and label nodes so we can add to summary.
  samples[common.IMAGE] = tf.identity(samples[common.IMAGE], name=common.IMAGE)
  samples[common.LABEL] = tf.identity(samples[common.LABEL], name=common.LABEL)
  model_options = common.ModelOptions(
      outputs_to_num_classes=outputs_to_num_classes,
      crop_size=[int(sz) for sz in FLAGS.train_crop_size],
      atrous_rates=FLAGS.atrous_rates,
      output_stride=FLAGS.output_stride)

  outputs_to_scales_to_logits = model.multi_scale_logits(
      samples[common.IMAGE],
      model_options=model_options,
      image_pyramid=FLAGS.image_pyramid,
      weight_decay=FLAGS.weight_decay,
      is_training=True,
      fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
      nas_training_hyper_parameters={
          'drop_path_keep_prob': FLAGS.drop_path_keep_prob,
          'total_training_steps': FLAGS.training_number_of_steps,
      })

  # Pull the logits out of the model output and add a name to the graph node
  # so we can add it to the summary, i.e.
  # output_type_dict['merged_logits'] gets the name 'semantic'.
  output_type_dict = outputs_to_scales_to_logits[common.OUTPUT_TYPE]
  output_type_dict[model.MERGED_LOGITS_SCOPE] = tf.identity(
      output_type_dict[model.MERGED_LOGITS_SCOPE], name=common.OUTPUT_TYPE)

  # six exists to bridge Python 2 and Python 3 (six = 2 * 3). In Python 2,
  # six.iteritems returns a generator, so large dicts are not materialized in
  # memory; in Python 3, dict.items already returns a view, so six.iteritems
  # is no longer needed there.
  for output, num_classes in six.iteritems(outputs_to_num_classes):
    # Here output='semantic' and num_classes=21.
    train_utils.add_softmax_cross_entropy_loss_for_each_scale(
        outputs_to_scales_to_logits[output],
        samples[common.LABEL],
        num_classes,
        ignore_label,
        loss_weight=1.0,
        upsample_logits=FLAGS.upsample_logits,
        hard_example_mining_step=FLAGS.hard_example_mining_step,
        top_k_percent_pixels=FLAGS.top_k_percent_pixels,
        scope=output)

    # Log the summaries.
    _log_summaries(samples[common.IMAGE], samples[common.LABEL], num_classes,
                   output_type_dict[model.MERGED_LOGITS_SCOPE])
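
For a single semantic-segmentation head, outputs_to_num_classes has one entry keyed by common.OUTPUT_TYPE ('semantic'). With the PASCAL VOC values from the comment above (21 classes; ignore label 255 is assumed here), the call would look like:

# Hypothetical call for PASCAL VOC-style training.
_build_deeplab(iterator,
               outputs_to_num_classes={common.OUTPUT_TYPE: 21},
               ignore_label=255)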
5. _log_summaries

Logs the summaries for the model.

Args:
  input_image: Input image, shape [batch_size, height, width, channel].
  label: Ground-truth label, shape [batch_size, height, width].
  num_of_classes: Number of classes for the dataset.
  output: Model output (logits), shape [batch_size, height, width, num_of_classes].
def _log_summaries(input_image, label, num_of_classes, output):
  # Add summaries for model variables.
  for model_var in tf.model_variables():
    tf.summary.histogram(model_var.op.name, model_var)

  # Add summaries for images, labels, semantic predictions.
  if FLAGS.save_summaries_images:
    # Save the input image, the label and the prediction to the summary.
    tf.summary.image('samples/%s' % common.IMAGE, input_image)

    # Scale up summary image pixel values for better visualization.
    pixel_scaling = max(1, 255 // num_of_classes)
    summary_label = tf.cast(label * pixel_scaling, tf.uint8)
    tf.summary.image('samples/%s' % common.LABEL, summary_label)

    predictions = tf.expand_dims(tf.argmax(output, 3), -1)
    summary_predictions = tf.cast(predictions * pixel_scaling, tf.uint8)
    tf.summary.image('samples/%s' % common.OUTPUT_TYPE, summary_predictions)
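
The scaling simply spreads the class indices over the 0-255 gray range so the label and prediction images are visible in TensorBoard. A worked example, assuming 21 classes:

# With num_of_classes = 21:
pixel_scaling = max(1, 255 // 21)  # = 12
# Class 0 -> gray 0, class 1 -> 12, ..., class 20 -> 240, instead of the
# near-black raw values 0..20.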
6. _average_gradients

Averages the gradients across towers.

Args:
  tower_grads: List of lists of (gradient, variable) tuples; the outer list
    is over towers, the inner list is over variables.

Returns:
  List of (gradient, variable) pairs where the gradient has been averaged
  across all towers.
def _average_gradients(tower_grads):
  average_grads = []
  for grad_and_vars in zip(*tower_grads):
    # Note that each grad_and_vars looks like the following:
    # ((grad0_gpu0, var0_gpu0), ..., (grad0_gpuN, var0_gpuN))
    grads, variables = zip(*grad_and_vars)
    grad = tf.reduce_mean(tf.stack(grads, axis=0), axis=0)

    # All vars are of the same value, using the first tower here.
    average_grads.append((grad, variables[0]))
  return average_grads