YOLOv3源碼閱讀之一:test_single_image.py

一、YOLO簡介

  YOLO(You Only Look Once)是一個高效的目標檢測算法,屬於One-Stage大家族,針對於Two-Stage目標檢測算法普遍存在的運算速度慢的缺點,YOLO創造性的提出了One-Stage。也就是將物體分類和物體定位在一個步驟中完成。YOLO直接在輸出層迴歸bounding box的位置和bounding box所屬類別,從而實現one-stage。

  經過兩次迭代,YOLO目前的最新版本爲YOLOv3,在前兩版的基礎上,YOLOv3進行了一些比較細節的改動,效果有所提升。

  本文正是希望可以將源碼加以註釋,方便自己學習,同時也願意分享出來和大家一起學習。由於本人還是一學生,如果有錯還請大家不吝指出。

  本文參考的源碼地址爲:https://github.com/wizyoung/YOLOv3_TensorFlow

二、代碼和註釋

  文件目錄:YOUR_PATH\YOLOv3_TensorFlow-master\test_single_image.py

  需要注意的是,我們默認輸入圖片尺寸爲[416,416][416, 416]

# coding: utf-8

from __future__ import division, print_function

import tensorflow as tf
import numpy as np
import argparse
import cv2

from utils.misc_utils import parse_anchors, read_class_names
from utils.nms_utils import gpu_nms
from utils.plot_utils import get_color_table, plot_one_box

from model import yolov3

# 設置命令行參數,具體可參見每一個命令行參數的含義
parser = argparse.ArgumentParser(description="YOLO-V3 test single image test procedure.")
parser.add_argument("input_image", type=str,
                    help="The path of the input image.")
parser.add_argument("--anchor_path", type=str, default="./data/yolo_anchors.txt",
                    help="The path of the anchor txt file.")
parser.add_argument("--new_size", nargs='*', type=int, default=[416, 416],
                    help="Resize the input image with `new_size`, size format: [width, height]")
parser.add_argument("--class_name_path", type=str, default="./data/coco.names",
                    help="The path of the class names.")
parser.add_argument("--restore_path", type=str, default="./data/darknet_weights/yolov3.ckpt",
                    help="The path of the weights to restore.")
args = parser.parse_args()

# 處理anchors,這些anchors是通過數據聚類獲得,一共9個,shape爲:[9, 2]。
# 需要注意的是,最後一個維度的順序是[width, height]
args.anchors = parse_anchors(args.anchor_path)

# 處理classes, 這裏是將所有的class的名稱提取了出來,組成了一個列表
args.classes = read_class_names(args.class_name_path)

# 類別的數目
args.num_class = len(args.classes)

# 根據類別的數目爲每一個類別分配不同的顏色,以便展示
color_table = get_color_table(args.num_class)

# 讀取圖片
img_ori = cv2.imread(args.input_image)

# 獲取圖片的尺寸
height_ori, width_ori = img_ori.shape[:2]

# resize,根據之前設定的尺寸值進行resize,默認是[416, 416],還是[width, height]的順序
img = cv2.resize(img_ori, tuple(args.new_size))

# 對圖片像素進行一定的數據處理
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = np.asarray(img, np.float32)
img = img[np.newaxis, :] / 255.

# TF會話
with tf.Session() as sess:
    # 輸入的placeholder,用於輸入圖片
    input_data = tf.placeholder(tf.float32, [1, args.new_size[1], args.new_size[0], 3], name='input_data')
    # 定義一個YOLOv3的類,在後面可以用來做模型建立以及loss計算等操作,參數分別是類別的數目和anchors
    yolo_model = yolov3(args.num_class, args.anchors)
    with tf.variable_scope('yolov3'):
        # 對圖片進行正向傳播,返回多張特徵圖
        pred_feature_maps = yolo_model.forward(input_data, False)
    # 對這些特徵圖進行處理,獲得計算出的bounding box以及屬於前景的概率已經每一個類別的概率分佈
    pred_boxes, pred_confs, pred_probs = yolo_model.predict(pred_feature_maps)

    # 將兩個概率值分別相乘就可以獲得最終的概率值
    pred_scores = pred_confs * pred_probs

    # 對這些bounding boxes和概率值進行非最大抑制(NMS)就可以獲得最後的bounding boxes和與其對應的概率值以及標籤
    boxes, scores, labels = gpu_nms(pred_boxes, pred_scores, args.num_class, max_boxes=30, score_thresh=0.4, nms_thresh=0.5)

    # Saver類,用以保存和恢復模型
    saver = tf.train.Saver()
    # 恢復模型參數
    saver.restore(sess, args.restore_path)

    # 運行graph,獲得對應tensors的具體數值,這裏是[boxes, scores, labels],對應於NMS之後獲得的結果
    boxes_, scores_, labels_ = sess.run([boxes, scores, labels], feed_dict={input_data: img})

    # rescale the coordinates to the original image
    # 將座標重新映射到原始圖片上,因爲前面的計算都是在resize之後的圖片上進行的,所以需要進行映射
    boxes_[:, 0] *= (width_ori/float(args.new_size[0]))
    boxes_[:, 2] *= (width_ori/float(args.new_size[0]))
    boxes_[:, 1] *= (height_ori/float(args.new_size[1]))
    boxes_[:, 3] *= (height_ori/float(args.new_size[1]))

    # 輸出
    print("box coords:")
    print(boxes_)
    print('*' * 30)
    print("scores:")
    print(scores_)
    print('*' * 30)
    print("labels:")
    print(labels_)

    # 繪製並展示,保存最後的結果
    for i in range(len(boxes_)):
        x0, y0, x1, y1 = boxes_[i]
        plot_one_box(img_ori, [x0, y0, x1, y1], label=args.classes[labels_[i]], color=color_table[labels_[i]])
    cv2.imshow('Detection result', img_ori)
    cv2.imwrite('detection_result.jpg', img_ori)
    cv2.waitKey(0)

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章