YOLOv3 (Keras-TF) Multi-Object Detection and Data Annotation

Data Preparation

If a public dataset already exists for the targets you want to detect, download it; the annotation format is one line per image:
image.jpg x1,y1,x2,y2,class1
For the VOC and COCO datasets you can run the provided conversion scripts directly to generate the dataset file and start training.
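For example, one line of such a dataset file can be split back into an image path and a list of boxes with a few lines of Python (a minimal sketch; the function name and the box values below are only illustrative):

# Parse one annotation line: "image_path x1,y1,x2,y2,cls x1,y1,x2,y2,cls ..."
def parse_annotation_line(line):
    parts = line.strip().split(' ')
    image_path = parts[0]
    boxes = [tuple(int(v) for v in box.split(',')) for box in parts[1:]]
    return image_path, boxes

# Illustrative values only:
# parse_annotation_line("images/frame1.jpg 120,80,260,210,1 400,150,520,300,2")
# -> ('images/frame1.jpg', [(120, 80, 260, 210, 1), (400, 150, 520, 300, 2)])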

Data Annotation

Here, OpenCV-Python mouse interaction combined with multi-object tracking is used to label and save multiple targets at once, producing a dataset file in the format YOLO reads directly.
This approach suits videos in which the targets appear continuously and are easy to track.
Usage:
Open a video file. For each target box, left-click its top-left and bottom-right corners, then press a key from 1 to 9 as its class label. Once every box in the frame is labeled, press Esc to start multi-object tracking and saving. The result looks like the following screenshot:
(screenshot of the annotation tool omitted)

Code:

# -*- coding: utf-8 -*-
import cv2
import numpy as np
import os
import sys
'''
    date: 2020/04/03
    author: Jiang
    function:
        Extract positive samples of the targets by tracking them.
        First click the top-left and bottom-right corners of each box by hand.
        After selecting a box, press '1'-'9' to assign and store its label; once all
        boxes are marked, press Esc to save and start tracking.
        If the tracking drifts too far, press 'esc' to re-annotate; press 'q' to quit.
    bug:
        a key command is still needed to play through video segments with no targets
'''

tracker_types = ['BOOSTING', 'MIL', 'KCF', 'TLD', 'MEDIANFLOW', 'MOSSE', 'CSRT']

def createTrackerByName(choose_tracker):
    # Set up the tracker selected by index from tracker_types.
    # Note: in newer opencv-contrib-python (>= 4.5) some of these constructors
    # (e.g. TrackerMOSSE, TrackerMedianFlow, TrackerBoosting, TrackerTLD, MultiTracker)
    # moved to cv2.legacy; the calls below match older 3.x/4.x releases.
    tracker_type = tracker_types[choose_tracker]

    if tracker_type == 'BOOSTING':          # AdaBoost-based; the box jitters
        tracker = cv2.TrackerBoosting_create()
    if tracker_type == 'MIL':               # poor results
        tracker = cv2.TrackerMIL_create()
    if tracker_type == 'KCF':               # too slow
        tracker = cv2.TrackerKCF_create()
    if tracker_type == 'TLD':               # high false-positive rate
        tracker = cv2.TrackerTLD_create()
    if tracker_type == 'MEDIANFLOW':        # works well!!!
        tracker = cv2.TrackerMedianFlow_create()
    if tracker_type == "CSRT":              # poor results
        tracker = cv2.TrackerCSRT_create()
    if tracker_type == "MOSSE":             # decent tracking; lower FPS than MEDIANFLOW, and the box grows when the target is far away
        tracker = cv2.TrackerMOSSE_create()
    return tracker

# Track the targets through the video and save annotated samples
def tracking_save(cap,frame,target_boxes,classes,path):
    global count_rects
    # Decide whether the source is a camera index (single digit) or a video file
    path = str(path)
    if len(path) == 1:
        pos_name = path
    else:
        pos_name = path.split('.')[0]
    choose_tracker = 6          # index into tracker_types (6 = CSRT)
    if not os.path.exists(pos_name):
        os.mkdir(pos_name)
    multiTracker = cv2.MultiTracker_create()
    # Initialize MultiTracker
    for bbox in target_boxes:
        #print(bbox)
        multiTracker.add(createTrackerByName(choose_tracker),frame, bbox)
    count = 0
    save = 0
    f = open(pos_name+'.txt','w')
    while True:
            # Read a new frame
            ok, frame = cap.read()
            if not ok:
                break
            save_frame = frame.copy()   # keep a clean copy (no overlays) for saving
            # Start timer
            timer = cv2.getTickCount()
            # Update tracker
            ok, boxes = multiTracker.update(frame)
            # Calculate Frames per second (FPS)
            fps = cv2.getTickFrequency() / (cv2.getTickCount() - timer)
            # Draw bounding box
            if ok:
                # Tracking success
                save_boxes = ""
                the_class = 0
                class_num = max(classes)            # class labels start at 1
                color_step = 255*3//class_num
                colors = []
                for i in range(class_num):
                    color_value = i*color_step
                    if color_value < 255:
                        b,g,r = color_value,0,0
                    elif 255 <=  color_value <= 255*2:
                        b,g,r = 255,color_value-255,0
                    else:
                        b,g,r = 255,255,color_value-255*2
                    color = (b,g,r)
                    if color == (0,0,0) :
                        color = (0,0,255)
                    elif color== (255,255,255):
                        color = (255,0,0)
                    colors.append(color)

                for i, newbox in enumerate(boxes):
                    th_class = classes[i]
                    color = colors[th_class - 1]
                    p1 = (int(newbox[0]), int(newbox[1]))
                    p2 = (int(newbox[0] + newbox[2]), int(newbox[1] + newbox[3]))
                    cv2.rectangle(frame, p1, p2, color, 2, 1)
                    save_boxes += " %s,%s,%s,%s,%s"%(p1[0],p1[1],p2[0],p2[1],th_class)
            
                #print (save_boxes,colors)
                count_rects += 1
                img_path =  pos_name + '/'+str(count_rects)+'.jpg'
                f.write(img_path + save_boxes +'\n')
                cv2.imwrite(img_path,save_frame)
                #cv2.imwrite(img_path,the_rect)   # (save cropped positives for AdaBoost training)
            else :
            # Tracking failure
                cv2.putText(frame, "Tracking failure detected", (100,80), cv2.FONT_HERSHEY_SIMPLEX, 0.75,(0,0,255),2)
    
            # Display tracker type on frame
            cv2.putText(frame, tracker_types[choose_tracker] + " Tracker", (100,20), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (50,170,50),2)
            # Display FPS on frame
            cv2.putText(frame, "FPS : " + str(int(fps)), (100,50), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (50,170,50), 2)
            #drawMask(frame,bbox[0]+5,bbox[1]+5,bbox[2]-10,bbox[3]-10)  
            # Display result

            show_img = frame.copy()
            cv2.imshow("show", show_img)
            # cv2.imwrite("./save/"+str(save)+".jpg", show_img)
            # Exit if ESC pressed
            k = cv2.waitKey(10) & 0xff
            if k == 27:
                f.close()
                return 1        # Esc: re-annotate from the current position
            elif k == ord('q'):
                #cv2.destroyWindow("Tracking")
                f.close()
                return 0        # 'q': quit

    # Video ended: close the annotation file and exit cleanly
    f.close()
    return 0
                
# Mouse callback: record the corner points of each box
def DrawCornerPoint(event,x,y,flags,param):
    global click_n,frame,corner_data
    if event == cv2.EVENT_LBUTTONDOWN :
        if click_n <= 1:
            corner_data.append((x,y))
            cv2.circle(frame,(x,y),2,(0,0,255),-1)   # draw the clicked corner point
            click_n+=1
            print (click_n)
            cv2.imshow('show',frame)

# Manually define the initial bounding boxes and their classes
def get_bbox(cap):
        global click_n,frame,corner_data
        # Exit if video not opened.
        if not cap.isOpened():
            print("Could not open video")
            sys.exit()
    
        # Read first frame.
        ok, frame = cap.read()
        if not ok:
            print('Cannot read video file')
            sys.exit()
        else:   # select the target boxes on this frame
            img0 = frame.copy()   # keep a clean copy of the frame for the tracker
            cv2.imshow('show',frame)
            cv2.setMouseCallback("show",DrawCornerPoint)
        boxes = []
        classes = []
        while 1:
            clicked_k = cv2.waitKey(20)
            if ord('1') <= clicked_k <= ord('9') and len(corner_data) == 2:
                # keys '1'-'9' give up to 9 classes; both corners must be clicked first
                classes.append(clicked_k - ord('0'))   # class labels start at 1
                print(classes)
                # tracking box
                box = (corner_data[0][0], corner_data[0][1],\
                    corner_data[1][0] - corner_data[0][0], corner_data[1][1]-corner_data[0][1])   # (x,y,w,h)
                boxes.append(box)
                click_n = 0
                corner_data = []   # reset the corner points for the next box
                print (clicked_k)
            
            elif clicked_k == 27:
                break
        print (boxes,classes)
        return img0,boxes,classes
        
# Check that every image referenced in the txt file exists, and display it
def test_samples(pathtxt):
    f = open(pathtxt,'r')
    lines = f.readlines()
    print (len(lines))
    f.close()

    for i,line in enumerate(lines):
        img_path = line.split(' ')[0]
        #img = cv2.imread(img_path)

        if os.path.exists(img_path) == False:
            print(i,line)
        else:
            cv2.imshow('show',cv2.imread(img_path))
            cv2.waitKey(100)
    cv2.destroyAllWindows()

if __name__ == "__main__":
  # yolo Row format: 	image_file_path box1 box2 ... boxN;
  #Box format: 		x_min,y_min,x_max,y_max,class_id
    path = '104.mp4'
    cap = cv2.VideoCapture(path)
    frame = 0#标志帧全局变量
    count_rects = 0#保存帧全局变量

    click_n = 0 #鼠标操作全局变量
    corner_data = []#对角点全局变量
    
    img0,boxes,classes = get_bbox(cap)
    key = tracking_save(cap,img0,boxes,classes,path)

    while key != 0:
        # Esc was pressed: re-annotate the boxes
        count_rects -= 2    # overwrite the last two (drifted) frames
        click_n = 0
        corner_data = []
        img0,boxes,classes = get_bbox(cap)
        key = tracking_save(cap,img0,boxes,classes,path)
    # 'q' was pressed (or the video ended): quit
    cap.release()
    cv2.destroyAllWindows()
    test_samples(path.split('.')[0]+'.txt')
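
When run on a video such as 104.mp4 (the path set in __main__), the script writes the saved frames into a folder named 104/ and an annotation file 104.txt whose lines already match the row format described above. A line might look like this (box values are illustrative only):

104/1.jpg 120,80,260,210,1 400,150,520,300,2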

Keras Training
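
A typical Keras YOLOv3 training script (for example the widely used keras-yolo3 codebase) takes an annotation file in exactly the row format produced above plus a class-names file, so the remaining preparation is a train/validation split. A minimal sketch, assuming hypothetical file names train.txt, val.txt and model_data/my_classes.txt:

# Prepare the annotation/class files that a keras-yolo3 style train.py expects.
# File names and class names here are assumptions, not part of the original tool.
import os
import random

with open('104.txt') as f:          # annotation file produced by the labeling tool
    lines = f.readlines()

random.seed(0)
random.shuffle(lines)
split = int(len(lines) * 0.9)       # 90% train / 10% validation

with open('train.txt', 'w') as f:
    f.writelines(lines[:split])
with open('val.txt', 'w') as f:
    f.writelines(lines[split:])

os.makedirs('model_data', exist_ok=True)
with open('model_data/my_classes.txt', 'w') as f:
    f.write('class1\nclass2\n')     # one name per line; hypothetical class names

Note that the labeling tool above writes class ids starting at 1 (keys 1-9), while most Keras YOLOv3 implementations index classes from 0, so either subtract 1 when the txt is written or account for the offset in the class-names file.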
