[fork] OpenPose detection in multiple scenarios (multi-person, single-person, images, video, angle computation)

1. Source

Part of the OpenPose detection code and the model come from a repository whose exact link I have forgotten; I will add it here if I find it again.

Also, plain pose detection needs no deep learning framework; OpenCV alone is enough. Single-person detection, however, requires an object detection step first; YOLO is used for that, which needs PyTorch.
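Since the installed OpenCV major version changes the return signature of findContours used in the code below, a quick check like this can save some debugging (a minimal sketch; it only relies on cv2.__version__):

import cv2

# OpenCV 3.x: image, contours, hierarchy = cv2.findContours(...)
# OpenCV 4.x: contours, hierarchy = cv2.findContours(...)
major = int(cv2.__version__.split('.')[0])
print('OpenCV {}: findContours returns {} values'.format(cv2.__version__, 3 if major == 3 else 2))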

2. OpenPose detection code

# -*- coding: utf-8 -*-
"""
Created on Fri May 24 23:18:36 2019

@author: wangwei
"""

import cv2
import time
import numpy as np
from random import randint

# Internal helper function
def getKeypoints(probMap, threshold=0.1):
    
    mapSmooth = cv2.GaussianBlur(probMap, (3,3), 0, 0)
    mapMask = np.uint8(mapSmooth > threshold)
    keypoints = []

    #find the blobs
    # Note: cv2.findContours returns 3 values in OpenCV 3.x but only 2 in 4.x;
    # taking the last two return values keeps this compatible with both.
    contours, _ = cv2.findContours(mapMask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]

    #for each blob find the maxima
    for cnt in contours:
        blobMask = np.zeros(mapMask.shape)
        blobMask = cv2.fillConvexPoly(blobMask, cnt, 1)
        maskedProbMap = mapSmooth * blobMask
        _, maxVal, _, maxLoc = cv2.minMaxLoc(maskedProbMap)
        keypoints.append(maxLoc + (probMap[maxLoc[1], maxLoc[0]],))

    #print('===========keypoints========{}'.format(keypoints))
    
    return keypoints


# Internal helper function
# Find valid connections between the different joints of all persons present
def getValidPairs(output, mapIdx, frameWidth, frameHeight, POSE_PAIRS, detected_keypoints):

    valid_pairs = []
    invalid_pairs = []
    n_interp_samples = 10
    paf_score_th = 0.1
    conf_th = 0.7
    
    # loop for every POSE_PAIR
    for k in range(len(mapIdx)):
        # A->B constitute a limb
        pafA = output[0, mapIdx[k][0], :, :]
        pafB = output[0, mapIdx[k][1], :, :]
        pafA = cv2.resize(pafA, (frameWidth, frameHeight))
        pafB = cv2.resize(pafB, (frameWidth, frameHeight))

        # Find the keypoints for the first and second limb
        candA = detected_keypoints[POSE_PAIRS[k][0]]
        candB = detected_keypoints[POSE_PAIRS[k][1]]
        nA = len(candA)
        nB = len(candB)


        # If keypoints for the joint-pair are detected
        # check every joint in candA with every joint in candB
        # Calculate the distance vector between the two joints
        # Find the PAF values at a set of interpolated points between the joints
        # Use the above formula to compute a score to mark the connection valid


        if( nA != 0 and nB != 0):
            valid_pair = np.zeros((0,3))
            for i in range(nA):
                max_j=-1
                maxScore = -1
                found = 0
                for j in range(nB):
                    # Find d_ij
                    d_ij = np.subtract(candB[j][:2], candA[i][:2])
                    norm = np.linalg.norm(d_ij)
                    if norm:
                        d_ij = d_ij / norm
                    else:
                        continue
                        
                    # Find p(u)
                    interp_coord = list(zip(np.linspace(candA[i][0], candB[j][0], num=n_interp_samples),
                                            np.linspace(candA[i][1], candB[j][1], num=n_interp_samples)))

                    # Find L(p(u))
                    paf_interp = []
                    for m in range(len(interp_coord)):
                        paf_interp.append([pafA[int(round(interp_coord[m][1])), int(round(interp_coord[m][0]))],
                                           pafB[int(round(interp_coord[m][1])), int(round(interp_coord[m][0]))] ])

                    # Find E
                    paf_scores = np.dot(paf_interp, d_ij)
                    avg_paf_score = sum(paf_scores)/len(paf_scores)

                    # Check if the connection is valid
                    # If the fraction of interpolated vectors aligned with PAF is higher than the threshold -> Valid Pair
                    if ( len(np.where(paf_scores > paf_score_th)[0]) / n_interp_samples ) > conf_th :
                        if avg_paf_score > maxScore:
                            max_j = j
                            maxScore = avg_paf_score
                            found = 1

                # Append the connection to the list
                if found:
                    valid_pair = np.append(valid_pair, [[candA[i][3], candB[max_j][3], maxScore]], axis=0)



            # Append the detected connections to the global list
            valid_pairs.append(valid_pair)
        else: # If no keypoints are detected            
            print("No Connection : k = {}".format(k))
            invalid_pairs.append(k)
            valid_pairs.append([])
            
    # print('=============valied-pairs======={}'.format(valid_pairs))
    # print('==============invalid-pairs========={}'.format(invalid_pairs))
    return valid_pairs, invalid_pairs


# Internal helper function
# This function creates a list of keypoints belonging to each person
# For each detected valid pair, it assigns the joint(s) to a person
def getPersonwiseKeypoints(valid_pairs, invalid_pairs, mapIdx, POSE_PAIRS, keypoints_list):

    # the last number in each row is the overall score
    personwiseKeypoints = -1 * np.ones((0, 19))

    for k in range(len(mapIdx)):
        if k not in invalid_pairs:
            partAs = valid_pairs[k][:,0]
            partBs = valid_pairs[k][:,1]
            indexA, indexB = np.array(POSE_PAIRS[k])

            for i in range(len(valid_pairs[k])):
                found = 0
                person_idx = -1
                for j in range(len(personwiseKeypoints)):
                    if personwiseKeypoints[j][indexA] == partAs[i]:
                        person_idx = j
                        found = 1
                        break

                if found:
                    personwiseKeypoints[person_idx][indexB] = partBs[i]
                    personwiseKeypoints[person_idx][-1] += keypoints_list[partBs[i].astype(int), 2] + valid_pairs[k][i][2]


                # if no partA was found in any existing subset, create a new subset
                elif not found and k < 17:
                    row = -1 * np.ones(19)
                    row[indexA] = partAs[i]
                    row[indexB] = partBs[i]
                    # add the keypoint_scores for the two keypoints and the paf_score

                    row[-1] = sum(keypoints_list[valid_pairs[k][i,:2].astype(int), 2]) + valid_pairs[k][i][2]
                    personwiseKeypoints = np.vstack([personwiseKeypoints, row])

    # print('===========personwisekeypoints=========={}'.format(personwiseKeypoints))
    return personwiseKeypoints


# Main function, intended to be called externally
def humanPoseDetector(img):
    """
    input: an image (possibly containing several people) on which to detect human poses
    output: the image with keypoints drawn, the image with the skeleton drawn, the locations
            of the detected keypoints, the valid pairs, the person-wise keypoints, and the keypoints list
    """
    # Load the network definition and weights
    protoFile = "./weights/pose_deploy_linevec.prototxt"
    weightsFile = "./weights/pose_iter_440000.caffemodel"

    nPoints = 18

    # COCO Output Format
    keypointsMapping = ['Nose', 'Neck', 'R-Sho', 'R-Elb', 'R-Wr', 'L-Sho', 'L-Elb', 'L-Wr', 'R-Hip', 
                    'R-Knee', 'R-Ank', 'L-Hip', 'L-Knee', 'L-Ank', 'R-Eye', 'L-Eye', 'R-Ear', 'L-Ear']

    POSE_PAIRS = [[1,2], [1,5], [2,3], [3,4], [5,6], [6,7],
                  [1,8], [8,9], [9,10], [1,11], [11,12], [12,13],
                  [1,0], [0,14], [14,16], [0,15], [15,17],
                  [2,17], [5,16] ]


    # index of PAFs corresponding to the POSE_PAIRS
    # e.g. for POSE_PAIR (1,2), the PAFs are located at indices (31,32) of the output; similarly, (1,5) -> (39,40) and so on.
    mapIdx = [[31,32], [39,40], [33,34], [35,36], [41,42], [43,44],
              [19,20], [21,22], [23,24], [25,26], [27,28], [29,30],
              [47,48], [49,50], [53,54], [51,52], [55,56],
              [37,38], [45,46]]


    colors = [ [0,100,255], [0,100,255], [0,255,255], [0,100,255], [0,255,255], [0,100,255],
             [0,255,0], [255,200,100], [255,0,255], [0,255,0], [255,200,100], [255,0,255],
             [0,0,255], [255,0,0], [200,200,0], [255,0,0], [200,200,0], [0,0,0]]
    
    
    frameWidth = img.shape[1]
    frameHeight = img.shape[0]

    t = time.time()
    net = cv2.dnn.readNetFromCaffe(protoFile, weightsFile)

    # Fix the input height and scale the input width according to the image aspect ratio
    inHeight = 368
    inWidth = int((inHeight/frameHeight)*frameWidth)

    inpBlob = cv2.dnn.blobFromImage(img, 1.0 / 255, (inWidth, inHeight),
                          (0, 0, 0), swapRB=False, crop=False)

    # Forward pass through the network
    net.setInput(inpBlob)
    output = net.forward()
    print("Time Taken in forward pass = {}".format(time.time() - t))

    detected_keypoints = []
    keypoints_list = np.zeros((0,3))
    keypoint_id = 0
    threshold = 0.1
    keypoints_location = []


    for part in range(nPoints):
        probMap = output[0,part,:,:]
        probMap = cv2.resize(probMap, (img.shape[1], img.shape[0]))
        keypoints = getKeypoints(probMap, threshold)     # detect candidate keypoints for this body part
        # Collect every keypoint's coordinates in one 2-D list: each element is a
        # three-element list holding the x, y coordinates and the joint index
        if keypoints:
            keypoints_temp = list(keypoints[0])
            keypoints_temp[2] = part                     # replace the confidence (third element) with the joint index
            keypoints_location.append(keypoints_temp)
        else:
            keypoints_location.append(keypoints)         # if nothing was detected, append an empty list
        
        print("Keypoints - {} : {}".format(keypointsMapping[part], keypoints))
        keypoints_with_id = []
        for i in range(len(keypoints)):
            keypoints_with_id.append(keypoints[i] + (keypoint_id,))
            keypoints_list = np.vstack([keypoints_list, keypoints[i]])
            keypoint_id += 1

        detected_keypoints.append(keypoints_with_id)

    keypointsImg = img.copy()

    for i in range(nPoints):
        for j in range(len(detected_keypoints[i])):
            cv2.circle(keypointsImg, detected_keypoints[i][j][0:2], 5, colors[i], -1, cv2.LINE_AA)

    #cv2.imshow("Keypoints",frameClone)

    valid_pairs, invalid_pairs = getValidPairs(output, mapIdx, frameWidth, frameHeight, POSE_PAIRS, detected_keypoints)
    personwiseKeypoints = getPersonwiseKeypoints(valid_pairs, invalid_pairs, mapIdx, POSE_PAIRS, keypoints_list)
    
    lineImg = keypointsImg.copy()


    for i in range(17):
        for n in range(len(personwiseKeypoints)):
            index = personwiseKeypoints[n][np.array(POSE_PAIRS[i])]
            if -1 in index:
                continue

            B = np.int32(keypoints_list[index.astype(int), 0])
            A = np.int32(keypoints_list[index.astype(int), 1])
            cv2.line(lineImg, (B[0], A[0]), (B[1], A[1]), colors[i], 3, cv2.LINE_AA)

    # timing from the DNN performance profiler
    t, _ = net.getPerfProfile()
    freq = cv2.getTickFrequency() / 1000
    # cv2.putText(lineImg, '%.2fms' % (t / freq), (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255))
    
    return keypointsImg, lineImg, keypoints_location, valid_pairs, personwiseKeypoints, keypoints_list
    # return keypointsImg, lineImg, keypoints_location, valid_pairs

    #cv2.imshow("Detected Pose" , frameClone)
    #cv2.waitKey(0)
    #cv2.destroyAllWindows()

The return values of this last function change often, which is why the example above returns so many of them. They are:

  • keypointsImg: a copy of the input image with the detected keypoints marked on it
  • lineImg: the image with the stick figure drawn on it
  • keypoints_location: the coordinates of the keypoints
  • valid_pairs: since not every point can be detected, this records the combinations of points that have valid connections
  • personwiseKeypoints: essentially another way of storing the keypoints, used later to extract the stick figure
  • keypoints_list: much the same as above
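To make these formats concrete, here is a hedged illustration of what the first and last of these structures might hold (all numbers are made up):

# keypoints_location: one entry per joint, [x, y, joint index], or [] if the joint was not detected
keypoints_location = [[320, 96, 0],     # Nose
                      [318, 140, 1],    # Neck
                      [],               # R-Sho missed
                      ...]              # and so on, 18 entries in total

# personwiseKeypoints: one row per person; the first 18 entries are indices into
# keypoints_list (-1 marks a missing joint) and the last entry is the overall score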

3. Scenarios

3.1 Detecting multiple people directly (image)

"""
簡單的進行檢測一張圖像裏的所有人的姿態
"""

from HumanPoseDetecte import humanPoseDetector

import cv2


PATH = 'data/test.jpg'

img = cv2.imread(PATH)

keypointsImg, lineImg, keypoints_location, valid_pairs,_,_ = humanPoseDetector(img)

print(keypoints_location)
print(valid_pairs)


cv2.imshow('1', keypointsImg)
cv2.imshow('2', lineImg)
cv2.waitKey(0)
cv2.destroyAllWindows()

cv2.imwrite('output/test_keypoints.jpg', keypointsImg)
cv2.imwrite('output/test_out.jpg', lineImg)

3.2 Detecting multiple people directly (video)

"""
針對視頻進行姿態檢測的程序
不過也只是簡單的檢測每一幀的所有人
"""

import cv2
import os
import time
from HumanPoseDetecte import humanPoseDetector

path = './data/video/'

outPath = './output/video'
if not os.path.exists(outPath):
    os.makedirs(outPath)        # makedirs also creates the parent './output' if it is missing

# Run detection on a single video
def run(video_path):
    video_name = video_path.split('/')[-1].split('.')[0]
    out_video_path = 'output/video/' + video_name + '-out.mp4'

    cap = cv2.VideoCapture(video_path)              # open the video
    input_fps = cap.get(cv2.CAP_PROP_FPS)           # frame rate

    video_frame_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) # total number of frames

    video_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    video_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    output_fps = int(input_fps)

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(out_video_path, fourcc, output_fps, (video_width, video_height))


    count = 0                   # frame counter, used to show progress
    while(cap.isOpened()):
        ret_val, frame = cap.read()

        if not ret_val:
            break
        else:
            if count % 100 == 0:
                print('{} / {} frames done'.format(count, video_frame_num))
            count += 1

            start = time.time()
            canvas = frame.copy()
            keypoints_img, line_img, _,_,_,_ = humanPoseDetector(canvas)

            finish = time.time()
            cv2.putText(line_img, "FPS:%f" %(1. / (finish-start)), (10, 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
            out.write(line_img)

            cv2.imshow('img', line_img)
            if cv2.waitKey(1) == ord('q'):
                break
    cap.release()
    out.release()
    cv2.destroyAllWindows()


video_names = os.listdir(path)

first = path + video_names[0]

run(first)
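The script above only processes the first video in the folder; to handle all of them, a simple loop would do (a sketch):

for name in video_names:
    run(path + name)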

3.3 Detecting a single person (image)

Single-person detection first runs object detection to crop out the image of a single person, and then applies pose detection to the crop. The object detection code itself is in Part 4; a sketch of the combined pipeline is shown below.
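Assuming a bounding box [x1, y1, x2, y2] has already been produced by some detector (the box below is made up, and the import paths are assumptions based on the layout used elsewhere in this post), the crop-then-detect pipeline might look like this:

import cv2

from HumanPoseDetecte import humanPoseDetector
from extractRectangleTarget import extractRectangleTarget

img = cv2.imread('data/test.jpg')
location = [120, 40, 360, 470]        # hypothetical [x1, y1, x2, y2] from an object detector

person = extractRectangleTarget(img, location)             # crop out the single person
keypointsImg, lineImg, keypoints_location, valid_pairs, _, _ = humanPoseDetector(person)

cv2.imshow('single-person pose', lineImg)
cv2.waitKey(0)
cv2.destroyAllWindows()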

4. Object detection code

Any object detection algorithm would work; this is just one example, using YOLO. It was also forked from a repository a long time ago; again, I have forgotten which one and will add the link if I find it.

4.1 Object detection

First, a look at the structure of the code:

(figure: project file structure)

4.1.1 Code in the main file

import cv2
import math
import time
import torch
import numpy as np

from torchvision import transforms
from utils.utils import *
from utils.datasets import *
from yolo_models import *
from face_models import Resnet50FaceModel, Resnet18FaceModel

# The three modules below are my own: OpenPose detection, target extraction, and angle computation
from openpose.HumanPoseDetecte import humanPoseDetector
from extractTarget.extractRectangleTarget import extractRectangleTarget
from calculateangle.calculateAngle import calculatekeypointsAngle

# tracker
class Tracker(object):
    def __init__(self):

        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


        """
        human&face detection
        """
        self.boxSize = 384
        self.yolov3 = "./cfg/yolov3orihf.cfg"
        self.dataConfigPath = "cfg/coco.data"
        self.weightsPath_hf = "weights/latest_h_f.pt"
        self.confThres = 0.5
        self.nmsThres = 0.45
        self.dataConfig = parse_data_config(self.dataConfigPath)
        self.classes = load_classes(self.dataConfig['names'])

        """
        identification
        """
        self.weightsPath_c = "./weights/res18_aug_market_cuhk.pth.tar"
        self.suspected_bbx = []
        self.infer_shape = (96, 128)
        # replay embedded vector buffer: store 10 timestep of embedded vector of target
        self.target_vector_buffer = np.zeros((10, 512))
        self.target_bbx = np.array([])
        self.bufferSize = 10
        self.bufferPointer = 0
        self.counter = 0      # initial value is 0
        self.way2 = True


    def getCenterModel(self):

        # model = Resnet50FaceModel
        model = Resnet18FaceModel
        model = model(False).to(self.device)
        checkpoint = torch.load(self.weightsPath_c)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        model.eval()

        return model

    def getHFDModel(self):

        model = Darknet(self.yolov3, self.boxSize)
        model.load_state_dict(torch.load(self.weightsPath_hf)['model'])
        model.to(self.device).eval()

        return model

    def getPoseModel(self):
        # Note: this method looks like a leftover from the forked repository;
        # cascaded_pose_net_dev, self.yoloBase and self.weightsPath are not defined in this file
        model = cascaded_pose_net_dev.PoseModel(cfg_path=self.yoloBase)
        model.load_state_dict(torch.load(self.weightsPath))
        # model = torch.nn.DataParallel(model)
        model.to(self.device).eval()

        return model

    def normalization(self, img, resize=False):
        if resize:
            # print(img.shape)
            h, w = img.shape[:2]
            img = cv2.resize(img, (0,0), fx=self.infer_shape[0]/w, fy=self.infer_shape[1]/h, interpolation=cv2.INTER_CUBIC)

        return img.astype(np.float32) / 255.

    def resizeRequested(self, img, height=96, width=96):

        height_, width_ = img.shape[:2]
        return cv2.resize(img, (0,0), fx=width/width_, fy=height/height_, interpolation=cv2.INTER_CUBIC)

    def iou_fillter(self):
        """Filter the suspected boxes by their IoU with the target box.

            Uses self.target_bbx, shape (4,): x1, y1, x2, y2, as the reference box
            and self.suspected_bbx, shape (n, 4), as the candidate boxes.
        """
        # box = (x1, y1, x2, y2)
        box = self.target_bbx[:]
        # print(box)
        boxes = np.array(self.suspected_bbx)
        if len(boxes) == 0 or len(box) == 0:
            return
        # print(boxes)
        box_area = (box[2] - box[0] + 1) * (box[3] - box[1] + 1)
        area = (boxes[:, 2] - boxes[:, 0] + 1) * (boxes[:, 3] - boxes[:, 1] + 1)

        # obtain the offsets of the intersection between the target box and the candidate boxes
        xx1 = np.maximum(box[0], boxes[:, 0])
        yy1 = np.maximum(box[1], boxes[:, 1])
        xx2 = np.minimum(box[2], boxes[:, 2])
        yy2 = np.minimum(box[3], boxes[:, 3])

        # compute the width and height of the bounding box
        w = np.maximum(0, xx2 - xx1 + 1)
        h = np.maximum(0, yy2 - yy1 + 1)

        inter = w * h
        ovr = inter / (box_area + area - inter)
        # select ovr > 0.4
        thre_ovr_idex = np.where(ovr > 0.4)
        # update boxes
        u_boxes = boxes[thre_ovr_idex]
        # update ovr
        ovr = ovr[thre_ovr_idex]

        if len(u_boxes) > 3:
            # return the top3 ovr index
            top3_index = np.argsort(ovr)[-3:]
            self.suspected_bbx = u_boxes[top3_index]
        elif len(u_boxes) == 1:
            self.suspected_bbx = u_boxes
        elif len(u_boxes) == 0:
            # the shot cut abruptly and all IoUs are 0; screen every predicted box to find the target again
            # the target's previous bbox is no longer useful for tracking, so clear it
            self.way2 = True
            self.target_bbx = np.array([])
            self.suspected_bbx = boxes
        # print(self.suspected_bbx)
    
    def indentification(self, img, canvas, model, query):
        # print('using indetification')
        """
        返回的參數增加了location,即用來框住目標的矩形的四個角的座標
        """

        imgs = []
        ori = img
        location = []
        
        if self.counter != 0:
            self.iou_fillter()
        # print('--------------------3-------------')
        if self.counter == 0:
            query_img = cv2.imread(query)
            #query_img = query
            query_img = self.normalization(query_img, resize=True)

            query_img = torch.from_numpy(query_img.transpose(2, 0, 1)).unsqueeze(0)
            query_img = query_img.to(self.device)
            _, embeddings = model(query_img)
            embeddings = embeddings.cpu().detach().numpy()
            self.target_vector_buffer[self.bufferPointer, :] = embeddings
            self.bufferPointer += 1

            # self.target_bbx = np.append(self.target_bbx, self.suspected_bbx[0])
            self.counter = 1
            # print('---------------------5------------------')
        else:

            for bbx in self.suspected_bbx:
                img = ori[int(bbx[1]):int(bbx[3]), int(bbx[0]):int(bbx[2]), :]
                img = self.normalization(img, resize=True)

                img = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0)
                imgs.append(img)
                # img = self.transform_for_infer(self.infer_shape)(img)
                # imgs.append(img.unsqueeze(0))

            if len(imgs) != 0:
                imgs = torch.cat(imgs, 0)
                imgs = imgs.to(self.device)
                # print(imgs.shape)
                # tic = time.time()
                _, embeddings = model(imgs)
                # toc = time.time()
                # print(toc-tic)
                embeddings = embeddings.cpu().detach().numpy() # (3, 512)

                distance = np.zeros((1, len(self.suspected_bbx))) # (1, 3) 3--bbox 10--vector buffer
                # average the embedding distance over the vectors stored so far
                if self.bufferPointer < 19:
                    for i in range(self.bufferPointer):
                        distance += np.sum((embeddings - np.expand_dims(self.target_vector_buffer[i, :], axis=0))**2, axis=1)
                    distance /= self.bufferPointer
                else:
                    for i in range(self.bufferSize):
                        distance += np.sum((embeddings - np.expand_dims(self.target_vector_buffer[i, :], axis=0))**2, axis=1)
                    distance /= self.bufferSize

                # distance = np.squeeze(distance)
                print(distance)
                # print('-----------------4---------------')

                # option 1: threshold the distance, e.g. < 0.4
                # index = np.where(distance < 0.4)
                # option 2: take the bbox with the smallest embedding distance
                index = np.argmin(distance[0])
                if self.way2:
                    if distance[0][index] < 0.6:
                        if self.bufferPointer > 9:
                            self.bufferPointer = 0

                        self.target_vector_buffer[self.bufferPointer, :] = embeddings[index, :]
                        self.bufferPointer += 1

                        x1, y1, x2, y2 = self.suspected_bbx[index]
                        # update the target's bbox
                        # print(self.target_bbx)
                        # print(self.suspected_bbx[index])
                        self.target_bbx = self.suspected_bbx[index]
                        label = 'Target %f' % distance[0][index]
                        plot_one_box([x1, y1, x2, y2], canvas, label=label, color=(0, 255, 170))
                        self.way2 = False
                        location = [x1, y1, x2, y2]
                else:
                    # print('-----------------6------------------')
                    if distance[0][index] < 0.4:
                        if self.bufferPointer > 9:
                            self.bufferPointer = 0
                        self.target_vector_buffer[self.bufferPointer, :] = embeddings[index, :]
                        self.bufferPointer += 1

                        x1, y1, x2, y2 = self.suspected_bbx[index]
                        # update the target's bbox
                        # print(self.target_bbx)
                        # print(self.suspected_bbx[index])
                        self.target_bbx = self.suspected_bbx[index]
                        label = 'Target %f'%distance[0][index]
                        plot_one_box([x1, y1, x2, y2], canvas, label=label, color=(0, 255, 170))
                        # print('-------------------------7--------------------')
                        location = [x1, y1, x2, y2]

        return canvas, location

    
    def humanFaceDetector(self, img, canvas, model):
        # print('using humanFaceDetector\n')

        ori = img
        img, _, _, _ = resize_square(img, height=self.boxSize, color=(127.5, 127.5, 127.5))
        img = img[:, :, ::-1].transpose(2, 0, 1)
        img = np.ascontiguousarray(img, dtype=np.float32)
        img = self.normalization(img)
        img = torch.from_numpy(img).unsqueeze(0).to(self.device)
        # print('-----------------1------------------')

        img_detections = []
        with torch.no_grad():
            pred = model(img)
            pred = pred[pred[:, :, 4] > self.confThres]

            if len(pred) > 0:
                detections = non_max_suppression(pred.unsqueeze(0), self.confThres, self.nmsThres)
                img_detections.extend(detections)
            else:
                detections = np.array([])
        # print('----------------------2--------------------')

        if len(detections) != 0:

            # The amount of padding that was added
            pad_x = max(ori.shape[0] - ori.shape[1], 0) * (self.boxSize / max(ori.shape))
            pad_y = max(ori.shape[1] - ori.shape[0], 0) * (self.boxSize / max(ori.shape))
            # Image height and width after padding is removed
            unpad_h = self.boxSize - pad_y
            unpad_w = self.boxSize - pad_x

            for x1, y1, x2, y2, conf, cls_conf, cls_pred in detections[0]:

                # Rescale coordinates to original dimensions
                box_h = ((y2 - y1) / unpad_h) * ori.shape[0]
                box_w = ((x2 - x1) / unpad_w) * ori.shape[1]
                y1 = (((y1 - pad_y // 2) / unpad_h) * ori.shape[0]).round().item()
                x1 = (((x1 - pad_x // 2) / unpad_w) * ori.shape[1]).round().item()
                x2 = (x1 + box_w).round().item()
                y2 = (y1 + box_h).round().item()
                x1, y1, x2, y2 = max(x1, 0), max(y1, 0), max(x2, 0), max(y2, 0)

                label = '%s %.2f' % (self.classes[int(cls_pred)], conf)
                color = [(255, 85, 0), (0, 255, 170)]

                if int(cls_pred) == 0:
                    self.suspected_bbx.append([x1, y1, x2, y2])
                    # plot_one_box([x1, y1, x2, y2], canvas, label=label, color=color[int(cls_pred)])
                # else:
                #     plot_one_box([x1, y1, x2, y2], canvas, label=label, color=color[int(cls_pred)])


        return canvas
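The post never shows the driver loop that ties the Tracker together with the pose and extraction modules. Continuing in the same main file (so the imports at its top are in scope), a hedged sketch might look like the following; the video path, the query image, and the per-frame reset of suspected_bbx are all assumptions:

tracker = Tracker()
hfd_model = tracker.getHFDModel()        # YOLO human/face detector
id_model = tracker.getCenterModel()      # re-identification model

cap = cv2.VideoCapture('./data/video/test.mp4')            # hypothetical input video
while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break
    canvas = frame.copy()
    tracker.suspected_bbx = []                             # assumed: reset candidates each frame
    canvas = tracker.humanFaceDetector(frame, canvas, hfd_model)
    canvas, location = tracker.indentification(frame, canvas, id_model, 'data/query.jpg')
    if location:
        person = extractRectangleTarget(frame, location)   # crop the identified target
        _, line_img, keypoints_location, _, _, _ = humanPoseDetector(person)
        cv2.imshow('target pose', line_img)
    if cv2.waitKey(1) == ord('q'):
        break
cap.release()
cv2.destroyAllWindows()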

4.1.2 Other code files used directly

There are a few other code files as well; just save them to the matching paths. They are not pasted here and will be uploaded later.

4.1.3 Target extraction code

# -*- coding: utf-8 -*-
"""
Created on Sat May 25 13:07:55 2019

@author: wangw
"""

"""
用於第一步框選出目標之後,提取矩形區域
暫時只提取出矩形區域即可
後期處理視頻的話,由於每一幀的目標矩形大小不確定,所以可能需要建一個稍微大的全黑圖像,將提取出的圖像放在其中
"""

import cv2
import numpy as np

def extractRectangleTarget(img, location):
    """
    圖像在存儲時是一個二維矩陣,其上某一像素點的座標爲(列,行)
    """
    # 矩形位置
    x1, x2, y1, y2 = int(location[0]), int(location[2]), int(location[1]), int(location[3])
    extract_target_img = img[y1:y2, x1:x2]
    
    return extract_target_img
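The module docstring above mentions pasting each crop onto a slightly larger all-black image so that every frame's output has a fixed size; a minimal sketch of that idea (the canvas size and helper name are arbitrary assumptions):

import numpy as np

def padToCanvas(crop, canvas_h=480, canvas_w=480):
    """Place a crop at the top-left corner of a fixed-size black canvas."""
    canvas = np.zeros((canvas_h, canvas_w, 3), dtype=crop.dtype)
    h, w = crop.shape[:2]
    canvas[:min(h, canvas_h), :min(w, canvas_w)] = crop[:canvas_h, :canvas_w]
    return canvas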
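Finally, the angle computation mentioned in the title: the calculateangle.calculateAngle module imported in 4.1.1 is not listed in the post. As a hedged sketch (the real calculatekeypointsAngle may well differ), the angle at a middle joint can be computed from three keypoints like this:

import numpy as np

def jointAngle(pt_a, pt_b, pt_c):
    """Angle at pt_b, in degrees, between the segments pt_b->pt_a and pt_b->pt_c;
    each point is an (x, y) pair, e.g. the first two elements of a keypoints_location entry."""
    v1 = np.array(pt_a[:2], dtype=float) - np.array(pt_b[:2], dtype=float)
    v2 = np.array(pt_c[:2], dtype=float) - np.array(pt_b[:2], dtype=float)
    cos_angle = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2) + 1e-8)
    return np.degrees(np.arccos(np.clip(cos_angle, -1.0, 1.0)))

# e.g. the right elbow angle from R-Sho (2), R-Elb (3) and R-Wr (4):
# angle = jointAngle(keypoints_location[2], keypoints_location[3], keypoints_location[4])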