[yolov3] 檢查 yolov3 數據標註是否正確

使用方法:

1. 安裝依賴包(見下方 requirements.txt)

2. 修改腳本中的以下參數(標 * 者為必須修改,標 × 者為可選):

* 2.1 labels:有幾類就寫幾類,按類別編號 0 到 n-1 遞增的順序排列

* 2.2 img_dir:修改為輸入圖片所在的資料夾路徑

* 2.3 yolo_txt_dir:修改為 YOLO 標註 .txt 文件所在的資料夾路徑(.txt 文件名需與圖片文件名相同)

× 2.4 scale_percent:顯示圖片時的縮放百分比(例如 80 表示顯示為原圖大小的 80%)

× 2.5 check_rate:要檢查的圖片佔全部圖片的比例,1 表示全部檢查,0.5 表示只檢查一半

× 2.6 color_list:每一類物體標註框的顏色,須按 OpenCV 使用的 BGR 順序填寫(而非 RGB)

3. 運行腳本

requirements.txt:

cycler==0.10.0
kiwisolver==1.1.0
matplotlib==3.1.2
numpy==1.18.1
opencv-python==4.1.2.30
Pillow==7.0.0
pyparsing==2.4.6
python-dateutil==2.8.1
six==1.13.0

Scripts:

# -*- coding: utf-8 -*-

import os
import random
import cv2 as cv
import matplotlib.pyplot as plt




# Class names, listed in class-index order (index 0 first).
labels = ["ArmorBlue", "ArmorRed", "Base", "Watcher"]
# Box colour for each class index, in OpenCV's (B, G, R) order. cv.rectangle
# interprets colours as BGR, so these are the BGR forms of the intended RGB
# colours: ArmorBlue -> blue, ArmorRed -> red, Base -> green, Watcher -> cyan.
# (The previous values were written in RGB order, which drew ArmorBlue boxes
# red and ArmorRed boxes blue.)
color_list = [
    (255, 0, 0),    # ArmorBlue: blue
    (0, 0, 255),    # ArmorRed: red
    (0, 255, 0),    # Base: green
    (255, 255, 0),  # Watcher: cyan
]
# Directory containing the images to check.
img_dir = "/home/youyheng/DJIdata/robomaster_Final_Tournament/image"
# Directory containing the YOLO .txt label files; each file shares its
# image's base name.
yolo_txt_dir = "/home/youyheng/DJIdata/robomaster_Final_Tournament/processedTXT"
# result_dst_dir = "/home/youyheng/DJIdata/robomaster_Final_Tournament/check_label_result"
# Display size of each image, as a percentage of its original size.
scale_percent = 80
# Fraction of the images to check: 1 checks all of them, 0.5 checks half.
check_rate = 1
# When True, images are visited in random order instead of directory order.
random_check = False

def cv_imread(file_path):
    """Read an image into an OpenCV-style 3-channel BGR ``uint8`` array.

    ``cv.imread`` cannot open paths containing non-ASCII (e.g. Chinese)
    characters on some platforms. To avoid that, the raw file bytes are read
    with NumPy and decoded in memory with ``cv.imdecode``.

    Args:
        file_path: Path of the image file; may contain non-ASCII characters.

    Returns:
        A BGR ``uint8`` array with exactly three channels, or ``None`` if the
        file is missing, unreadable, empty, or cannot be decoded (mirroring
        ``cv.imread``, so callers can test the result against ``None``).
    """
    import numpy as np  # local import: NumPy is not imported at file level

    try:
        raw = np.fromfile(file_path, dtype=np.uint8)
    except OSError:
        return None
    if raw.size == 0:
        # cv.imdecode rejects an empty buffer; treat it as an unreadable image.
        return None
    # IMREAD_COLOR always yields 3-channel BGR, even for grayscale or RGBA
    # sources, which the previous plt.imread + cvtColor(BGR2RGB) path crashed on.
    return cv.imdecode(raw, cv.IMREAD_COLOR)


def my_line(img, start, end):
    """Draw a black line segment on ``img`` in place.

    The segment runs from ``start`` to ``end`` (pixel points) with a thickness
    of 2 px and OpenCV line type 8. This helper appears to be unused elsewhere
    in this script.
    """
    black = (0, 0, 0)
    cv.line(img, start, end, black, 2, 8)


def draw_label_rec(img, label_index, label_info_list, img_name):
    """Draw one YOLO bounding box on ``img`` (in place) and stamp the image name.

    Args:
        img: Image array to draw on, e.g. as returned by ``cv_imread``.
        label_index: Class index (int or numeric string); selects the box
            colour from the module-level ``color_list``.
        label_info_list: ``[x_center, y_center, width, height]`` in YOLO's
            normalized form, i.e. as fractions of the image width/height.
            Items may be strings or numbers.
        img_name: Text written in the top-left corner (the image file name).

    Raises:
        IndexError: if ``color_list`` has no entry for ``label_index`` or
            ``label_info_list`` has fewer than four items.
    """
    img_height, img_width = img.shape[0], img.shape[1]

    x = float(label_info_list[0])
    y = float(label_info_list[1])
    w = float(label_info_list[2])
    h = float(label_info_list[3])

    # Convert the normalized centre/size into pixel corner coordinates.
    x_center = x * img_width
    y_center = y * img_height
    half_w = w * img_width / 2
    # The box height must come from h; it was previously computed from w,
    # which drew every box with the wrong height.
    half_h = h * img_height / 2

    xmin, xmax = int(x_center - half_w), int(x_center + half_w)
    ymin, ymax = int(y_center - half_h), int(y_center + half_h)

    cv.rectangle(img,
                 (xmin, ymin),  # top-left corner
                 (xmax, ymax),  # bottom-right corner
                 color_list[int(label_index)],  # BGR colour for this class
                 2)  # line thickness

    # Redrawn at the same spot for every box of the image; the result is identical.
    cv.putText(img, str(img_name), (5, 50), cv.FONT_HERSHEY_SIMPLEX, 1,
               (0, 255, 0), 2)


def _list_images(directory):
    """Return the names of the ``.jpg``/``.jpeg`` files in ``directory``."""
    # Case-insensitive, so e.g. "IMG_001.JPG" is not silently skipped.
    return [name for name in os.listdir(directory)
            if name.lower().endswith((".jpg", ".jpeg"))]


def _read_yolo_labels(txt_path):
    """Parse a YOLO label file into ``(class_index, [x, y, w, h])`` pairs.

    Coordinates are returned as the raw strings from the file. Blank lines
    are ignored. Lines with fewer than five fields are reported and skipped.
    """
    boxes = []
    with open(txt_path, "r", encoding="utf-8") as reader:
        for line_no, line in enumerate(reader, start=1):
            fields = line.split()
            if not fields:
                continue  # blank/trailing lines previously crashed with IndexError
            if len(fields) < 5:
                print("Malformed label line {0} in {1}: {2!r}".format(
                    line_no, txt_path, line.rstrip("\n")))
                continue
            boxes.append((int(fields[0]), fields[1:5]))
    return boxes


def main():
    """Interactively display labeled images with their YOLO boxes drawn on top.

    Uses the module-level settings ``img_dir``, ``yolo_txt_dir``,
    ``check_rate``, ``random_check`` and ``scale_percent``. Only images that
    have a matching ``<name>.txt`` in ``yolo_txt_dir`` are shown. Press any
    key in the window to move on to the next image.
    """
    origin_window = "Origin Window"

    # Filter to image files *before* applying check_rate, so the rate is a
    # fraction of the images rather than of every file in the directory.
    img_name_list = _list_images(img_dir)
    if random_check:
        random.shuffle(img_name_list)
    # Clamp at 0 so a negative rate checks nothing instead of slicing from the end.
    check_max_times = max(0, int(check_rate * len(img_name_list)))

    for img_name in img_name_list[:check_max_times]:
        img_path = os.path.join(img_dir, img_name)
        print("**check img : {0} **".format(img_path))

        # Look for the label file first, so unlabeled images are never decoded.
        txt_path = os.path.join(yolo_txt_dir,
                                img_name.rpartition(".")[0] + ".txt")
        if not os.path.exists(txt_path):
            continue

        src_image = cv_imread(img_path)
        if src_image is None:
            # Skip an unreadable image instead of aborting the whole check.
            print("Open image Error")
            continue

        for label_index, coords in _read_yolo_labels(txt_path):
            draw_label_rec(src_image, label_index, coords, img_name)

        # Scale the annotated image for display. Clamp each side to >= 1 px
        # so cv.resize never receives a zero-sized target.
        src_height, src_width = src_image.shape[0], src_image.shape[1]
        dim = (max(1, int(src_width * scale_percent / 100)),
               max(1, int(src_height * scale_percent / 100)))
        resized_src = cv.resize(src_image, dim, interpolation=cv.INTER_CUBIC)

        cv.imshow(origin_window, resized_src)
        cv.waitKey(0)  # block until any key is pressed
        cv.destroyAllWindows()
        print("**check over**")


if __name__ == "__main__":
    # Executed directly (not imported): start the interactive label checker.
    main()

效果預覽:(按任意鍵跳轉到下一張圖片。由於 OpenCV 的 cv.imread 無法讀取含中文的路徑,路徑中含有中文時檢查將會中止)

2020/01/08 更新:改用 cv_imread 函數讀取圖片,已解決中文路徑問題

效果預覽:

有中文路徑的情況:

 

發佈了214 篇原創文章 · 獲贊 54 · 訪問量 5萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章