yolov3 (keras-tf) multi-object detection and data annotation

Data preparation

If a public dataset already exists for the targets you want to detect, simply download it. The annotation format is:
image.jpg x1,y1,x2,y2,class1
After downloading the VOC or COCO dataset, you can run the provided conversion scripts directly to generate the annotation file (dataset) used for training.
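
For reference, below is a minimal sketch (my own illustration, not from the original post) that parses one line of this annotation format into an image path and an N x 5 box array; the file name train.txt is only a placeholder.

import numpy as np

def parse_annotation_line(line):
    # "image.jpg x1,y1,x2,y2,class x1,y1,x2,y2,class ..." -> (path, N x 5 int array)
    parts = line.strip().split()
    image_path = parts[0]
    boxes = np.array([list(map(int, box.split(','))) for box in parts[1:]])
    return image_path, boxes

with open('train.txt') as f:                # placeholder annotation file
    for line in f:
        image_path, boxes = parse_annotation_line(line)
        print(image_path, boxes)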

Data annotation

Here, opencv-python mouse interaction combined with multi-object tracking is used to annotate multiple targets and save them in a dataset format that YOLO can read directly.
It is suitable for videos in which the targets appear continuously and are easy to track.
Usage:
Open a video file. For each target box, left-click its top-left and bottom-right corners, then press a digit key 1-9 as the class label. Once all boxes in the frame are annotated, press Esc to start multi-object tracking over the video and save the results. The effect is shown in the screenshot below:

[Screenshot: multi-object annotation and tracking demo]

Code:

# -*- coding: utf-8 -*-
import cv2
import numpy as np
import os
import sys
'''
    date: 2020/04/03
    author: Jiang
    function:
        Extract positive target samples by tracking.
        First manually mark the top-left and bottom-right corners of each box.
        After selecting a box, press '1'-'9' to assign its class label; once all boxes
        are marked, press Esc to start tracking and saving.
        If the tracking drifts too much, press 'Esc' to re-annotate; press 'q' to quit.
    bug:
        a key command is still needed to play through video segments without targets
'''

tracker_types = ['BOOSTING', 'MIL', 'KCF', 'TLD', 'MEDIANFLOW', 'MOSSE', 'CSRT']

def createTrackerByName(choose_tracker):
    # Create a single-object tracker selected by its index in tracker_types
    tracker_type = tracker_types[choose_tracker]

    if tracker_type == 'BOOSTING':          # AdaBoost-based, jittery
        tracker = cv2.TrackerBoosting_create()
    elif tracker_type == 'MIL':             # poor results
        tracker = cv2.TrackerMIL_create()
    elif tracker_type == 'KCF':             # too slow
        tracker = cv2.TrackerKCF_create()
    elif tracker_type == 'TLD':             # high false-positive rate
        tracker = cv2.TrackerTLD_create()
    elif tracker_type == 'MEDIANFLOW':      # works well
        tracker = cv2.TrackerMedianFlow_create()
    elif tracker_type == 'CSRT':            # poor results
        tracker = cv2.TrackerCSRT_create()
    elif tracker_type == 'MOSSE':           # decent tracking, lower FPS than MEDIANFLOW, box grows at long range
        tracker = cv2.TrackerMOSSE_create()
    return tracker
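
# Note (environment assumption): with opencv-contrib-python >= 4.5 the legacy
# constructors used above (TrackerBoosting_create, TrackerTLD_create,
# TrackerMedianFlow_create, TrackerMOSSE_create) and cv2.MultiTracker_create
# used below live under cv2.legacy, e.g. cv2.legacy.TrackerMOSSE_create();
# this script keeps the older cv2.Tracker*_create API from the original post.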

# Track the annotated targets through the video and save the samples
def tracking_save(cap, frame, target_boxes, classes, path):
    global count_rects
    # Decide whether the source is a camera index or a video file path
    path = str(path)
    if len(path) == 1:
        pos_name = path
    else:
        pos_name = path.split('.')[0]
    choose_tracker = 6                      # index into tracker_types
    if not os.path.exists(pos_name):
        os.mkdir(pos_name)
    multiTracker = cv2.MultiTracker_create()
    # Initialize MultiTracker
    for bbox in target_boxes:
        #print(bbox)
        multiTracker.add(createTrackerByName(choose_tracker),frame, bbox)
    count = 0
    save = 0
    f = open(pos_name + '.txt', 'a')        # append mode: re-annotation after Esc keeps earlier lines (delete the txt before a fresh run)
    while True:
            # Read a new frame
            ok, frame = cap.read()
            if not ok:
                break
            save_frame = frame.copy()       # clean copy without overlays, used for saving
            # Start timer
            timer = cv2.getTickCount()
            # Update tracker
            ok, boxes = multiTracker.update(frame)
            # Calculate Frames per second (FPS)
            fps = cv2.getTickFrequency() / (cv2.getTickCount() - timer)
            # Draw bounding box
            if ok:
                # Tracking success
                save_boxes = ""
                class_num = max(classes)            # class ids start at 1
                # Build one distinct BGR colour per class by stepping along a B->G->R ramp
                color_step = 255 * 3 // class_num
                colors = []
                for i in range(class_num):
                    color_value = i * color_step
                    if color_value < 255:
                        b, g, r = color_value, 0, 0
                    elif 255 <= color_value <= 255 * 2:
                        b, g, r = 255, color_value - 255, 0
                    else:
                        b, g, r = 255, 255, color_value - 255 * 2
                    color = (b, g, r)
                    if color == (0, 0, 0):
                        color = (0, 0, 255)
                    elif color == (255, 255, 255):
                        color = (255, 0, 0)
                    colors.append(color)

                for i, newbox in enumerate(boxes):
                    th_class = classes[i]
                    color = colors[th_class - 1]
                    p1 = (int(newbox[0]), int(newbox[1]))
                    p2 = (int(newbox[0] + newbox[2]), int(newbox[1] + newbox[3]))
                    cv2.rectangle(frame, p1, p2, color, 2, 1)
                    save_boxes += " %s,%s,%s,%s,%s"%(p1[0],p1[1],p2[0],p2[1],th_class)
            
                #print (save_boxes,colors)
                count_rects += 1
                img_path =  pos_name + '/'+str(count_rects)+'.jpg'
                f.write(img_path + save_boxes +'\n')
                cv2.imwrite(img_path,save_frame)
                #cv2.imwrite(img_path,the_rect)    # save AdaBoost positive samples (unused)
            else :
            # Tracking failure
                cv2.putText(frame, "Tracking failure detected", (100,80), cv2.FONT_HERSHEY_SIMPLEX, 0.75,(0,0,255),2)
    
            # Display tracker type on frame
            cv2.putText(frame, tracker_types[choose_tracker] + " Tracker", (100,20), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (50,170,50),2)
            # Display FPS on frame
            cv2.putText(frame, "FPS : " + str(int(fps)), (100,50), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (50,170,50), 2)
            #drawMask(frame,bbox[0]+5,bbox[1]+5,bbox[2]-10,bbox[3]-10)  
            # Display result

            show_img = frame.copy()
            cv2.imshow("show", show_img)
            # cv2.imwrite("./save/"+str(save)+".jpg", show_img)
            # Exit if ESC pressed
            k = cv2.waitKey(10) & 0xff
            if k == 27:
                # Esc: stop tracking so the boxes can be re-drawn
                f.close()
                return 1
            elif k == ord('q'):
                # q: quit
                #cv2.destroyWindow("Tracking")
                f.close()
                return 0
                
# Mouse callback: record the clicked corner points of a box
def DrawCornerPoint(event, x, y, flags, param):
    global click_n, frame, corner_data
    if event == cv2.EVENT_LBUTTONDOWN:
        if click_n <= 1:
            corner_data.append((x, y))
            cv2.circle(frame, (x, y), 2, (0, 0, 255), -1)   # mark the clicked point
            click_n += 1
            print(click_n)
            cv2.imshow('show', frame)

# Manually draw the initial bounding boxes on the first frame
def get_bbox(cap):
        global click_n, frame, corner_data
        # Exit if video not opened.
        if not cap.isOpened():
            print("Could not open video")
            sys.exit()

        # Read first frame.
        ok, frame = cap.read()
        if not ok:
            print('Cannot read video file')
            sys.exit()
        img0 = frame.copy()
        # Annotate the target boxes on this frame
        cv2.imshow('show', frame)
        cv2.setMouseCallback("show", DrawCornerPoint)
        boxes = []
        classes = []
        while 1:
            clicked_k = cv2.waitKey(20)
            # keys '1'-'9' assign the class label (up to 9 classes, ids start at 1)
            if ord('1') <= clicked_k <= ord('9') and len(corner_data) == 2:
                classes.append(clicked_k - ord('0'))
                print(classes)
                # tracking box in (x, y, w, h) form
                box = (corner_data[0][0], corner_data[0][1],
                    corner_data[1][0] - corner_data[0][0], corner_data[1][1] - corner_data[0][1])
                boxes.append(box)
                click_n = 0
                corner_data = []            # reset the corner points for the next box
                print(clicked_k)

            elif clicked_k == 27:
                break
        print(boxes, classes)
        return img0, boxes, classes
        
# Sanity check: verify that the saved images match the lines in the txt file
def test_samples(pathtxt):
    f = open(pathtxt,'r')
    lines = f.readlines()
    print (len(lines))
    f.close()

    for i,line in enumerate(lines):
        img_path = line.split(' ')[0]
        #img = cv2.imread(img_path)

        if os.path.exists(img_path) == False:
            print(i,line)
        else:
            cv2.imshow('show',cv2.imread(img_path))
            cv2.waitKey(100)
    cv2.destroyAllWindows()

if __name__ == "__main__":
    # YOLO row format:  image_file_path box1 box2 ... boxN
    # Box format:       x_min,y_min,x_max,y_max,class_id
    path = '104.mp4'
    cap = cv2.VideoCapture(path)
    frame = 0           # global: frame currently being annotated
    count_rects = 0     # global: number of saved frames

    click_n = 0         # global: mouse click counter
    corner_data = []    # global: clicked corner points (top-left, bottom-right)
    
    img0,boxes,classes = get_bbox(cap)
    key = tracking_save(cap,img0,boxes,classes,path)

    while key != 0:
        # Esc was pressed: re-draw the boxes and continue annotating
        count_rects -= 2            # overwrite the previous two frames
        click_n = 0
        corner_data = []
        img0, boxes, classes = get_bbox(cap)
        key = tracking_save(cap, img0, boxes, classes, path)
    # q was pressed: quit
    cap.release()
    cv2.destroyAllWindows()
    test_samples(path.split('.')[0]+'.txt')
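
After the script exits, a folder named after the video (here 104/) holds the saved frames, and 104.txt holds one annotation line per saved frame in the row format described in the comments above. A hypothetical example (the coordinates below are made up purely for illustration):

104/1.jpg 120,85,260,240,1 400,150,470,230,2
104/2.jpg 122,86,262,242,1 401,151,471,231,2

This txt file can then be used directly as the annotation input for Keras YOLOv3 training.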

Keras training
