YOLOv3源碼閱讀之六:train.py

一、YOLO簡介

  YOLO(You Only Look Once)是一個高效的目標檢測算法,屬於One-Stage大家族。針對Two-Stage目標檢測算法普遍存在的運算速度慢的缺點,YOLO創造性地提出了One-Stage的思路,也就是將物體分類和物體定位在一個步驟中完成:YOLO直接在輸出層迴歸bounding box的位置和bounding box所屬類別,從而實現one-stage。

  經過兩次迭代,YOLO目前的最新版本爲YOLOv3,在前兩版的基礎上,YOLOv3進行了一些比較細節的改動,效果有所提升。

  本文正是希望可以將源碼加以註釋,方便自己學習,同時也願意分享出來和大家一起學習。由於本人還是一名學生,如果有錯還請大家不吝指出。

  本文參考的源碼地址爲:https://github.com/wizyoung/YOLOv3_TensorFlow

二、代碼和註釋

  文件目錄:YOUR_PATH\YOLOv3_TensorFlow-master\train.py

  這一部分代碼主要是訓練模型的入口,按照要求準備好訓練數據之後,就可以從這裏開始訓練了。

# coding: utf-8

from __future__ import division, print_function

import tensorflow as tf
import numpy as np
import logging
from tqdm import trange

import args

from utils.data_utils import get_batch_data
from utils.misc_utils import shuffle_and_overwrite, make_summary, config_learning_rate, config_optimizer, AverageMeter
from utils.eval_utils import evaluate_on_cpu, evaluate_on_gpu, get_preds_gpu, voc_eval, parse_gt_rec
from utils.nms_utils import gpu_nms

from model import yolov3

# setting loggers
# Configure the root logger: records of level DEBUG and above are written to
# args.progress_log_path. filemode='w' truncates the file on every run.
logging.basicConfig(
    filename=args.progress_log_path,
    filemode='w',
    level=logging.DEBUG,
    format='%(asctime)s %(levelname)s %(message)s',
    datefmt='%a, %d %b %Y %H:%M:%S',
)

# setting placeholders
# Entry points through which values are fed into the graph.

# Training/inference switch passed to yolo_model.forward(); relevant to ops
# such as batch normalization that behave differently in the two phases.
is_training = tf.placeholder(tf.bool, name="phase_train")

# NOTE(review): this placeholder is never fed or read anywhere in this script
# (both sess.run calls feed only is_training); it appears to be a leftover
# from an iterator-handle based input pipeline.
handle_flag = tf.placeholder(tf.string, [], name='iterator_handle_flag')

# register the gpu nms operation here for the following evaluation scheme
# Build the GPU NMS op once, up front, so the evaluation helpers
# (evaluate_on_gpu / get_preds_gpu) can reuse it by feeding raw predictions
# through these two placeholders.
# presumably boxes are [1, num_boxes, 4] and scores [1, num_boxes, class_num] — TODO confirm
pred_boxes_flag = tf.placeholder(tf.float32, [1, None, None])
pred_scores_flag = tf.placeholder(tf.float32, [1, None, None])
gpu_nms_op = gpu_nms(pred_boxes_flag, pred_scores_flag, args.class_num, args.nms_topk, args.score_threshold, args.nms_threshold)

##################
# tf.data pipeline
##################
# Training input: every line of args.train_file describes one sample, so the
# raw lines are read with a TextLineDataset and parsed later.
train_dataset = tf.data.TextLineDataset(args.train_file)
# Shuffle with a buffer of args.train_img_cnt lines (presumably the size of
# the whole training set).
train_dataset = train_dataset.shuffle(args.train_img_cnt)
# Batch the raw lines *before* parsing: get_batch_data below receives a whole
# batch of lines per call.
train_dataset = train_dataset.batch(args.batch_size)
# Parse each batch of lines with the Python helper get_batch_data. It yields
# 5 tensors: image ids, images, and the three y_true targets (13/26/52).
train_dataset = train_dataset.map(
    lambda x: tf.py_func(get_batch_data,
                         inp=[x, args.class_num, args.img_size, args.anchors, 'train', args.multi_scale_train, args.use_mix_up],
                         Tout=[tf.int64, tf.float32, tf.float32, tf.float32, tf.float32]),
    num_parallel_calls=args.num_threads
)
# Overlap input preparation with training.
train_dataset = train_dataset.prefetch(args.prefetech_buffer)

# Validation input: same parsing, but one image per batch, no shuffling, and
# multi-scale / mix-up disabled (the two trailing False arguments).
val_dataset = tf.data.TextLineDataset(args.val_file)
val_dataset = val_dataset.batch(1)
val_dataset = val_dataset.map(
    lambda x: tf.py_func(get_batch_data,
                         inp=[x, args.class_num, args.img_size, args.anchors, 'val', False, False],
                         Tout=[tf.int64, tf.float32, tf.float32, tf.float32, tf.float32]),
    num_parallel_calls=args.num_threads
)
# Dataset transformations return a *new* dataset, so the result must be
# assigned back; otherwise prefetching is silently skipped.
val_dataset = val_dataset.prefetch(args.prefetech_buffer)

# One reinitializable iterator shared by both datasets: running
# train_init_op / val_init_op selects which dataset get_next() reads from.
iterator = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)
train_init_op = iterator.make_initializer(train_dataset)
val_init_op = iterator.make_initializer(val_dataset)

# get an element from the chosen dataset iterator
image_ids, image, y_true_13, y_true_26, y_true_52 = iterator.get_next()
y_true = [y_true_13, y_true_26, y_true_52]

# tf.data pipeline will lose the data `static` shape, so we need to set it manually
# (py_func outputs carry no static shape information).
image_ids.set_shape([None])
image.set_shape([None, None, None, 3])
for y in y_true:
    y.set_shape([None, None, None, None, None])

##################
# Model definition
##################
# Build the model the same way as at prediction time, under the 'yolov3'
# variable scope.
yolo_model = yolov3(args.class_num, args.anchors, args.use_label_smooth, args.use_focal_loss, args.batch_norm_decay, args.weight_decay)
with tf.variable_scope('yolov3'):
    pred_feature_maps = yolo_model.forward(image, is_training=is_training)

# Loss terms, indexed below as [0] total, [1] xy, [2] wh, [3] conf, [4] class.
loss = yolo_model.compute_loss(pred_feature_maps, y_true)

# Decoded predictions, fetched by the training loop for evaluation.
y_pred = yolo_model.predict(pred_feature_maps)

# Sum of the regularization losses registered in the graph
# (presumably the weight decay configured via args.weight_decay).
l2_loss = tf.losses.get_regularization_loss()

# setting restore parts and vars to update
# saver_to_restore loads only the variables matched by args.restore_part;
# update_vars restricts training to the variables matched by args.update_part.
saver_to_restore = tf.train.Saver(var_list=tf.contrib.framework.get_variables_to_restore(include=args.restore_part))
update_vars = tf.contrib.framework.get_variables_to_restore(include=args.update_part)

# TensorBoard curves: each loss term, the L2 loss and the L2 / total ratio.
tf.summary.scalar('train_batch_statistics/total_loss', loss[0])
tf.summary.scalar('train_batch_statistics/loss_xy', loss[1])
tf.summary.scalar('train_batch_statistics/loss_wh', loss[2])
tf.summary.scalar('train_batch_statistics/loss_conf', loss[3])
tf.summary.scalar('train_batch_statistics/loss_class', loss[4])
tf.summary.scalar('train_batch_statistics/loss_l2', l2_loss)
tf.summary.scalar('train_batch_statistics/loss_ratio', l2_loss / loss[0])

# global step
# A float variable in LOCAL_VARIABLES, initialised to args.global_step and set
# by tf.local_variables_initializer(). NOTE(review): default tf.train.Saver()
# instances collect global variables only, so this counter is presumably not
# written to checkpoints.
global_step = tf.Variable(float(args.global_step), trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])

# Learning-rate schedule. With warm-up, the rate grows linearly from 0 to
# args.learning_rate_init over the first train_batch_num * warm_up_epoch steps;
# afterwards config_learning_rate sees the step count measured from the end
# of warm-up.
if args.use_warm_up:
    learning_rate = tf.cond(tf.less(global_step, args.train_batch_num * args.warm_up_epoch), 
                            lambda: args.learning_rate_init * global_step / (args.train_batch_num * args.warm_up_epoch),
                            lambda: config_learning_rate(args, global_step - args.train_batch_num * args.warm_up_epoch))
else:
    learning_rate = config_learning_rate(args, global_step)
tf.summary.scalar('learning_rate', learning_rate)

# A tf.train.Saver() records the variables that exist when it is constructed.
# Without optimizer state: build the savers before any optimizer variables
# are added to the graph.
if not args.save_optimizer:
    saver_to_save = tf.train.Saver()
    saver_best = tf.train.Saver()

# Optimizer
optimizer = config_optimizer(args.optimizer_name, learning_rate)

# set dependencies for BN ops
# Run the UPDATE_OPS collection (e.g. batch-norm moving-average updates)
# before every training step.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss[0] + l2_loss, var_list=update_vars, global_step=global_step)

# With optimizer state: the savers must be built only *after* minimize(),
# because that is where the tf.train optimizers create their slot variables
# (e.g. Adam moments). Constructing the optimizer object alone adds none, so
# building the savers before minimize() would silently drop that state.
if args.save_optimizer:
    saver_to_save = tf.train.Saver()
    saver_best = tf.train.Saver()

# Create the session and run the training / evaluation loop.
with tf.Session() as sess:
    # Initialize all variables; global_step lives in LOCAL_VARIABLES, so the
    # local initializer is required as well.
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

    # Overwrite the variables selected by args.restore_part with the values
    # stored in the checkpoint at args.restore_path.
    saver_to_restore.restore(sess, args.restore_path)
    merged = tf.summary.merge_all()
    writer = tf.summary.FileWriter(args.log_dir, sess.graph)

    print('\n----------- start to train -----------\n')

    # Best validation mAP seen so far. np.inf is the canonical spelling;
    # the np.Inf alias was removed in NumPy 2.0.
    best_mAP = -np.inf

    for epoch in range(args.total_epoches):

        # Make the shared iterator yield training batches.
        sess.run(train_init_op)

        # Running means of each loss term over the current epoch.
        loss_total, loss_xy, loss_wh, loss_conf, loss_class = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()

        for i in trange(args.train_batch_num):
            _, summary, __y_pred, __y_true, __loss, __global_step, __lr = sess.run(
                [train_op, merged, y_pred, y_true, loss, global_step, learning_rate],
                feed_dict={is_training: True})

            writer.add_summary(summary, global_step=__global_step)

            # Weight every update by len(__y_pred[0]), presumably the number
            # of images in this batch (TODO confirm against yolov3.predict).
            __n_imgs = len(__y_pred[0])
            loss_total.update(__loss[0], __n_imgs)
            loss_xy.update(__loss[1], __n_imgs)
            loss_wh.update(__loss[2], __n_imgs)
            loss_conf.update(__loss[3], __n_imgs)
            loss_class.update(__loss[4], __n_imgs)

            # Every args.train_evaluation_step steps, report recall/precision
            # on the last training batch together with the running losses.
            if __global_step % args.train_evaluation_step == 0 and __global_step > 0:
                # evaluate_on_cpu (imported above) is a CPU-based alternative.
                recall, precision = evaluate_on_gpu(sess, gpu_nms_op, pred_boxes_flag, pred_scores_flag, __y_pred, __y_true, args.class_num, args.eval_threshold)

                info = "Epoch: {}, global_step: {} | loss: total: {:.2f}, xy: {:.2f}, wh: {:.2f}, conf: {:.2f}, class: {:.2f} | ".format(
                        epoch, int(__global_step), loss_total.average, loss_xy.average, loss_wh.average, loss_conf.average, loss_class.average)
                info += 'Last batch: rec: {:.3f}, prec: {:.3f} | lr: {:.5g}'.format(recall, precision, __lr)
                print(info)
                logging.info(info)

                writer.add_summary(make_summary('evaluation/train_batch_recall', recall), global_step=__global_step)
                writer.add_summary(make_summary('evaluation/train_batch_precision', precision), global_step=__global_step)

                # Abort as soon as the running loss becomes NaN.
                if np.isnan(loss_total.average):
                    print('****' * 10)
                    raise ArithmeticError(
                        'Gradient exploded! Please train again and you may need modify some parameters.')

        # Snapshot the epoch's mean training loss *before* resetting the
        # meters: it drives the save condition and goes into the checkpoint name.
        tmp_total_loss = loss_total.average
        loss_total.reset()
        loss_xy.reset()
        loss_wh.reset()
        loss_conf.reset()
        loss_class.reset()

        # NOTE: this is just demo. You can set the conditions when to save the weights.
        if epoch % args.save_epoch == 0 and epoch > 0:
            if tmp_total_loss <= 2.:
                saver_to_save.save(sess, args.save_dir + 'model-epoch_{}_step_{}_loss_{:.4f}_lr_{:.5g}'.format(epoch, int(__global_step), tmp_total_loss, __lr))

        # switch to validation dataset for evaluation
        if epoch % args.val_evaluation_epoch == 0 and epoch > 0:
            sess.run(val_init_op)

            val_loss_total, val_loss_xy, val_loss_wh, val_loss_conf, val_loss_class = \
                AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()

            val_preds = []

            # The validation dataset is batched with size 1: one step per image.
            for j in trange(args.val_img_cnt):
                __image_ids, __y_pred, __loss = sess.run([image_ids, y_pred, loss],
                                                         feed_dict={is_training: False})
                pred_content = get_preds_gpu(sess, gpu_nms_op, pred_boxes_flag, pred_scores_flag, __image_ids, __y_pred)
                val_preds.extend(pred_content)
                val_loss_total.update(__loss[0])
                val_loss_xy.update(__loss[1])
                val_loss_wh.update(__loss[2])
                val_loss_conf.update(__loss[3])
                val_loss_class.update(__loss[4])

            # calc mAP
            rec_total, prec_total, ap_total = AverageMeter(), AverageMeter(), AverageMeter()
            gt_dict = parse_gt_rec(args.val_file, args.img_size)

            info = '======> Epoch: {}, global_step: {}, lr: {:.6g} <======\n'.format(epoch, __global_step, __lr)

            # Per-class VOC-style evaluation. Recall and precision are averaged
            # with weights npos and nd returned by voc_eval; AP is averaged
            # uniformly over classes to give mAP.
            for ii in range(args.class_num):
                npos, nd, rec, prec, ap = voc_eval(gt_dict, val_preds, ii, iou_thres=args.eval_threshold, use_07_metric=False)
                info += 'EVAL: Class {}: Recall: {:.4f}, Precision: {:.4f}, AP: {:.4f}\n'.format(ii, rec, prec, ap)
                rec_total.update(rec, npos)
                prec_total.update(prec, nd)
                ap_total.update(ap, 1)

            mAP = ap_total.avg
            info += 'EVAL: Recall: {:.4f}, Precison: {:.4f}, mAP: {:.4f}\n'.format(rec_total.avg, prec_total.avg, mAP)
            info += 'EVAL: loss: total: {:.2f}, xy: {:.2f}, wh: {:.2f}, conf: {:.2f}, class: {:.2f}\n'.format(
                val_loss_total.avg, val_loss_xy.avg, val_loss_wh.avg, val_loss_conf.avg, val_loss_class.avg)
            print(info)
            logging.info(info)

            if mAP > best_mAP:
                best_mAP = mAP
                saver_best.save(sess, args.save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format(
                                   epoch, int(__global_step), best_mAP, val_loss_total.avg, __lr))

            # Record the same averages that were printed above. Validation
            # summaries are indexed by epoch rather than by global step.
            writer.add_summary(make_summary('evaluation/val_mAP', mAP), global_step=epoch)
            writer.add_summary(make_summary('evaluation/val_recall', rec_total.avg), global_step=epoch)
            writer.add_summary(make_summary('evaluation/val_precision', prec_total.avg), global_step=epoch)
            writer.add_summary(make_summary('validation_statistics/total_loss', val_loss_total.avg), global_step=epoch)
            writer.add_summary(make_summary('validation_statistics/loss_xy', val_loss_xy.avg), global_step=epoch)
            writer.add_summary(make_summary('validation_statistics/loss_wh', val_loss_wh.avg), global_step=epoch)
            writer.add_summary(make_summary('validation_statistics/loss_conf', val_loss_conf.avg), global_step=epoch)
            writer.add_summary(make_summary('validation_statistics/loss_class', val_loss_class.avg), global_step=epoch)

    # Flush pending events so the last summaries are not lost on exit.
    writer.close()



發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章