【fork】Various OpenPose detection scenarios (multi-person, single-person, image, video, angle calculation)

1. Source

Part of the OpenPose detection code and the model come from a repository I forked; I have forgotten the exact link and will add it here once I find it again.

Also, pure pose detection does not require a deep-learning framework; OpenCV alone is enough. If you want single-person detection, however, you first need an object-detection algorithm; YOLO is used here, which requires PyTorch.

2. Code for pose detection with OpenPose

# -*- coding: utf-8 -*-
"""
Created on Fri May 24 23:18:36 2019

@author: wangwei
"""

import cv2
import time
import numpy as np
from random import randint

# Internal helper function
def getKeypoints(probMap, threshold=0.1):
    
    mapSmooth = cv2.GaussianBlur(probMap, (3,3), 0, 0)
    mapMask = np.uint8(mapSmooth > threshold)
    keypoints = []

    # find the blobs
    # Depending on your OpenCV version, cv2.findContours below may return a different
    # number of values; with OpenCV 4.x just delete the first underscore
    _, contours, _ = cv2.findContours(mapMask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

    # for each blob, find the maximum of the probability map
    for cnt in contours:
        blobMask = np.zeros(mapMask.shape)
        blobMask = cv2.fillConvexPoly(blobMask, cnt, 1)
        maskedProbMap = mapSmooth * blobMask
        _, maxVal, _, maxLoc = cv2.minMaxLoc(maskedProbMap)
        keypoints.append(maxLoc + (probMap[maxLoc[1], maxLoc[0]],))
        
        
    #print('===========keypoints========{}'.format(keypoints))
    
    return keypoints


# Internal helper function
# Find valid connections between the different joints of all persons present
def getValidPairs(output, mapIdx, frameWidth, frameHeight, POSE_PAIRS, detected_keypoints):

    valid_pairs = []
    invalid_pairs = []
    n_interp_samples = 10
    paf_score_th = 0.1
    conf_th = 0.7
    
    # loop for every POSE_PAIR
    for k in range(len(mapIdx)):
        # A->B constitute a limb
        pafA = output[0, mapIdx[k][0], :, :]
        pafB = output[0, mapIdx[k][1], :, :]
        pafA = cv2.resize(pafA, (frameWidth, frameHeight))
        pafB = cv2.resize(pafB, (frameWidth, frameHeight))

        # Find the keypoints for the first and second limb
        candA = detected_keypoints[POSE_PAIRS[k][0]]
        candB = detected_keypoints[POSE_PAIRS[k][1]]
        nA = len(candA)
        nB = len(candB)


        # If keypoints for the joint-pair is detected
        # check every joint in candA with every joint in candB
        # Calculate the distance vector between the two joints
        # Find the PAF values at a set of interpolated points between the joints
        # Use the above formula to compute a score to mark the connection valid


        if( nA != 0 and nB != 0):
            valid_pair = np.zeros((0,3))
            for i in range(nA):
                max_j=-1
                maxScore = -1
                found = 0
                for j in range(nB):
                    # Find d_ij
                    d_ij = np.subtract(candB[j][:2], candA[i][:2])
                    norm = np.linalg.norm(d_ij)
                    if norm:
                        d_ij = d_ij / norm
                    else:
                        continue
                        
                    # Find p(u)
                    interp_coord = list(zip(np.linspace(candA[i][0], candB[j][0], num=n_interp_samples),
                                            np.linspace(candA[i][1], candB[j][1], num=n_interp_samples)))

                    # Find L(p(u))
                    paf_interp = []
                    # use a new index variable here so the outer loop variable k is not shadowed
                    for m in range(len(interp_coord)):
                        paf_interp.append([pafA[int(round(interp_coord[m][1])), int(round(interp_coord[m][0]))],
                                           pafB[int(round(interp_coord[m][1])), int(round(interp_coord[m][0]))] ])

                    # Find E
                    paf_scores = np.dot(paf_interp, d_ij)
                    avg_paf_score = sum(paf_scores)/len(paf_scores)

                    # Check if the connection is valid
                    # If the fraction of interpolated vectors aligned with the PAF is higher than the threshold -> valid pair
                    if ( len(np.where(paf_scores > paf_score_th)[0]) / n_interp_samples ) > conf_th :
                        if avg_paf_score > maxScore:
                            max_j = j
                            maxScore = avg_paf_score
                            found = 1

                # Append the connection to the list
                if found:
                    valid_pair = np.append(valid_pair, [[candA[i][3], candB[max_j][3], maxScore]], axis=0)



            # Append the detected connections to the global list
            valid_pairs.append(valid_pair)
        else: # If no keypoints are detected            
            print("No Connection : k = {}".format(k))
            invalid_pairs.append(k)
            valid_pairs.append([])
            
    # print('=============valied-pairs======={}'.format(valid_pairs))
    # print('==============invalid-pairs========={}'.format(invalid_pairs))
    return valid_pairs, invalid_pairs


# Internal helper function
# This function creates a list of keypoints belonging to each person
# For each detected valid pair, it assigns the joint(s) to a person
def getPersonwiseKeypoints(valid_pairs, invalid_pairs, mapIdx, POSE_PAIRS, keypoints_list):

    # the last number in each row is the overall score
    personwiseKeypoints = -1 * np.ones((0, 19))

    for k in range(len(mapIdx)):
        if k not in invalid_pairs:
            partAs = valid_pairs[k][:,0]
            partBs = valid_pairs[k][:,1]
            indexA, indexB = np.array(POSE_PAIRS[k])

            for i in range(len(valid_pairs[k])):
                found = 0
                person_idx = -1
                for j in range(len(personwiseKeypoints)):
                    if personwiseKeypoints[j][indexA] == partAs[i]:
                        person_idx = j
                        found = 1
                        break

                if found:
                    personwiseKeypoints[person_idx][indexB] = partBs[i]
                    personwiseKeypoints[person_idx][-1] += keypoints_list[partBs[i].astype(int), 2] + valid_pairs[k][i][2]


                # if find no partA in the subset, create a new subset
                elif not found and k < 17:
                    row = -1 * np.ones(19)
                    row[indexA] = partAs[i]
                    row[indexB] = partBs[i]
                    # add the keypoint_scores for the two keypoints and the paf_score

                    row[-1] = sum(keypoints_list[valid_pairs[k][i,:2].astype(int), 2]) + valid_pairs[k][i][2]
                    personwiseKeypoints = np.vstack([personwiseKeypoints, row])

    # print('===========personwisekeypoints=========={}'.format(personwiseKeypoints))
    return personwiseKeypoints


# Main function intended to be called from outside
def humanPoseDetector(img):
    """
    input: one image(contain just one person) to detect the human pose
    output: the image whose size is changed and pose is drawed and the location of keypoints that are detected 
            and the valied pairs
    """
    # Load the network definition and the weights
    protoFile = "./weights/pose_deploy_linevec.prototxt"
    weightsFile = "./weights/pose_iter_440000.caffemodel"

    nPoints = 18

    # COCO Output Format
    keypointsMapping = ['Nose', 'Neck', 'R-Sho', 'R-Elb', 'R-Wr', 'L-Sho', 'L-Elb', 'L-Wr', 'R-Hip', 
                    'R-Knee', 'R-Ank', 'L-Hip', 'L-Knee', 'L-Ank', 'R-Eye', 'L-Eye', 'R-Ear', 'L-Ear']

    POSE_PAIRS = [[1,2], [1,5], [2,3], [3,4], [5,6], [6,7],
                  [1,8], [8,9], [9,10], [1,11], [11,12], [12,13],
                  [1,0], [0,14], [14,16], [0,15], [15,17],
                  [2,17], [5,16] ]


    # indices of the PAFs corresponding to the POSE_PAIRS
    # e.g. for POSE_PAIR (1,2), the PAFs are located at indices (31,32) of the output; similarly, (1,5) -> (39,40) and so on.
    mapIdx = [[31,32], [39,40], [33,34], [35,36], [41,42], [43,44],
              [19,20], [21,22], [23,24], [25,26], [27,28], [29,30],
              [47,48], [49,50], [53,54], [51,52], [55,56],
              [37,38], [45,46]]


    colors = [ [0,100,255], [0,100,255], [0,255,255], [0,100,255], [0,255,255], [0,100,255],
             [0,255,0], [255,200,100], [255,0,255], [0,255,0], [255,200,100], [255,0,255],
             [0,0,255], [255,0,0], [200,200,0], [255,0,0], [200,200,0], [0,0,0]]
    
    
    frameWidth = img.shape[1]
    frameHeight = img.shape[0]

    t = time.time()
    net = cv2.dnn.readNetFromCaffe(protoFile, weightsFile)

    # Fix the input height and scale the input width according to the image aspect ratio
    inHeight = 368
    inWidth = int((inHeight/frameHeight)*frameWidth)

    inpBlob = cv2.dnn.blobFromImage(img, 1.0 / 255, (inWidth, inHeight),
                          (0, 0, 0), swapRB=False, crop=False)

    # Forward pass through the network
    net.setInput(inpBlob)
    output = net.forward()
    print("Time Taken in forward pass = {}".format(time.time() - t))

    detected_keypoints = []
    keypoints_list = np.zeros((0,3))
    keypoint_id = 0
    threshold = 0.1
    keypoints_location = []


    for part in range(nPoints):
        probMap = output[0,part,:,:]
        probMap = cv2.resize(probMap, (img.shape[1], img.shape[0]))
        keypoints = getKeypoints(probMap, threshold)
        # Store the coordinates of all keypoints in a 2-D list: each element is a
        # three-element list holding the (x, y) coordinates and the joint index
        if keypoints != []:
            keypoints_temp = list(keypoints[0])
            keypoints_temp[2] = part                         # replace the confidence with the joint index
            keypoints_location.append(keypoints_temp)
        else:
            keypoints_location.append(keypoints)             # if nothing is detected, append an empty list
        
        print("Keypoints - {} : {}".format(keypointsMapping[part], keypoints))
        keypoints_with_id = []
        for i in range(len(keypoints)):
            keypoints_with_id.append(keypoints[i] + (keypoint_id,))
            keypoints_list = np.vstack([keypoints_list, keypoints[i]])
            keypoint_id += 1

        detected_keypoints.append(keypoints_with_id)

    keypointsImg = img.copy()

    for i in range(nPoints):
        for j in range(len(detected_keypoints[i])):
            cv2.circle(keypointsImg, detected_keypoints[i][j][0:2], 5, colors[i], -1, cv2.LINE_AA)

    #cv2.imshow("Keypoints",frameClone)

    valid_pairs, invalid_pairs = getValidPairs(output, mapIdx, frameWidth, frameHeight, POSE_PAIRS, detected_keypoints)
    personwiseKeypoints = getPersonwiseKeypoints(valid_pairs, invalid_pairs, mapIdx, POSE_PAIRS, keypoints_list)
    
    lineImg = keypointsImg.copy()


    for i in range(17):
        for n in range(len(personwiseKeypoints)):
            index = personwiseKeypoints[n][np.array(POSE_PAIRS[i])]
            if -1 in index:
                continue

            B = np.int32(keypoints_list[index.astype(int), 0])
            A = np.int32(keypoints_list[index.astype(int), 1])
            cv2.line(lineImg, (B[0], A[0]), (B[1], A[1]), colors[i], 3, cv2.LINE_AA)

    # timing information reported by the network
    t, _ = net.getPerfProfile()
    freq = cv2.getTickFrequency() / 1000
    # cv2.putText(lineImg, '%.2fms' % (t / freq), (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255))
    
    return keypointsImg, lineImg, keypoints_location, valid_pairs, personwiseKeypoints, keypoints_list
    # return keypointsImg, lineImg, keypoints_location, valid_pairs

    #cv2.imshow("Detected Pose" , frameClone)
    #cv2.waitKey(0)
    #cv2.destroyAllWindows()

The return values of this last function are changed frequently depending on what is needed, which is why the example above returns quite a few of them:

  • keypointsImg: the original image with the detected keypoints drawn and marked on it
  • lineImg: the image with the "stick figure" (skeleton) already drawn
  • keypoints_location: the coordinates of the keypoints
  • valid_pairs: since not every keypoint can be detected, this records which keypoint pairs form valid connections
  • personwiseKeypoints: essentially another way of storing the keypoints, grouped per person; mainly used later when extracting the stick figure
  • keypoints_list: largely the same information as above, stored as a flat array (see the short sketch below for how these last two are used together)
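
As a quick illustration of how the last two structures fit together, here is a minimal sketch (assuming humanPoseDetector has already been run): each row of personwiseKeypoints holds, for every one of the 18 joints, an index into keypoints_list (or -1 if that joint was not found), plus an overall score in the last column.

# Minimal sketch: recover the (x, y) coordinates of every joint, grouped per person.
# Assumes personwiseKeypoints and keypoints_list come from humanPoseDetector.
for n, person in enumerate(personwiseKeypoints):
    joints = {}
    for part in range(18):                       # 18 COCO keypoints
        idx = int(person[part])
        if idx == -1:                            # this joint was not detected
            continue
        x, y = keypoints_list[idx, 0], keypoints_list[idx, 1]
        joints[part] = (int(x), int(y))
    print('person {}: {}'.format(n, joints))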

3. Different scenarios

3.1 Detecting multiple people directly (image)

"""
简单的进行检测一张图像里的所有人的姿态
"""

from HumanPoseDetecte import humanPoseDetector

import cv2


PATH = 'data/test.jpg'

img = cv2.imread(PATH)

keypointsImg, lineImg, keypoints_location, valid_pairs,_,_ = humanPoseDetector(img)

print(keypoints_location)
print(valid_pairs)


cv2.imshow('1', keypointsImg)
cv2.imshow('2', lineImg)
cv2.waitKey(0)
cv2.destroyAllWindows()

cv2.imwrite('output/test_keypoints.jpg', keypointsImg)
cv2.imwrite('output/test_out.jpg', lineImg)

3.2 Detecting multiple people directly (video)

"""
针对视频进行姿态检测的程序
不过也只是简单的检测每一帧的所有人
"""

import cv2
import os
import time
from HumanPoseDetecte import humanPoseDetector

path = './data/video/'

outPath = './output/video'
if not os.path.exists(outPath):
    os.makedirs(outPath)        # makedirs also creates the parent 'output' directory if needed

# function that processes a single video
def run(video_path):
    video_name = video_path.split('/')[-1].split('.')[0]
    out_video_path = 'output/video/' + video_name + '-out.mp4'

    cap = cv2.VideoCapture(video_path)              # open the video
    input_fps = cap.get(cv2.CAP_PROP_FPS)           # frame rate

    video_frame_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) # number of frames

    video_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    video_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    output_fps = int(input_fps)

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(out_video_path, fourcc, output_fps, (video_width, video_height))


    count = 0                   # frame counter, used to show progress
    while(cap.isOpened()):
        ret_val, frame = cap.read()

        if not ret_val:
            break
        else:
            if count % 100 == 0:
                print('{} / {} have done'.format(count, video_frame_num))
            count += 1

            start = time.time()
            canvas = frame.copy()
            keypoints_img, line_img, _,_,_,_ = humanPoseDetector(canvas)

            finish = time.time()
            cv2.putText(line_img, "FPS:%f" %(1. / (finish-start)), (10, 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
            out.write(line_img)

            cv2.imshow('img', line_img)
            if cv2.waitKey(1) == ord('q'):
                break
    cap.release()
    out.release()
    cv2.destroyAllWindows()


video_names = os.listdir(path)

first = path + video_names[0]

run(first)
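
The script above only processes the first video in the folder. To handle every video, simply loop over video_names instead:

for name in video_names:
    run(path + name)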

3.3 Detecting a single person (image)

Because single-person detection first requires object detection to extract the image of that one person, pose detection is then run on the crop. The object-detection code is covered in part 4; a rough sketch of the crop-then-detect pipeline is shown below.
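
The snippet below is a minimal sketch of this pipeline, assuming an object detector has already produced a bounding box location = [x1, y1, x2, y2] for the person (part 4 shows one way to obtain it). The import paths follow the layout used in part 4; adjust them to wherever you saved the files.

import cv2
from openpose.HumanPoseDetecte import humanPoseDetector
from extractTarget.extractRectangleTarget import extractRectangleTarget

img = cv2.imread('data/test.jpg')
location = [100, 50, 300, 400]      # placeholder box; normally produced by the detector in part 4

# crop the person out of the frame, then run pose detection on the crop only
person_img = extractRectangleTarget(img, location)
keypointsImg, lineImg, keypoints_location, _, _, _ = humanPoseDetector(person_img)

cv2.imshow('single person pose', lineImg)
cv2.waitKey(0)
cv2.destroyAllWindows()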

4. Object-detection code

Any object-detection algorithm will do; this is just one example, using YOLO. It was also forked from a repository a long time ago, and again I have forgotten which one; I will add the link once I find it.

4.1 Object detection

First, let's look at the structure of the code:
[Figure: screenshot of the project's directory structure]

4.1.1 Code in the main file

import cv2
import math
import time
import numpy as np
import torch
from torchvision import transforms

from utils.utils import *
from utils.datasets import *
from yolo_models import *
from face_models import Resnet50FaceModel, Resnet18FaceModel

# The three modules below are my own: OpenPose detection, target extraction, and angle calculation
from openpose.HumanPoseDetecte import humanPoseDetector
from extractTarget.extractRectangleTarget import extractRectangleTarget
from calculateangle.calculateAngle import calculatekeypointsAngle

# tracker
class Tracker(object):
    def __init__(self):

        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


        """
        human&face detection
        """
        self.boxSize = 384
        self.yolov3 = "./cfg/yolov3orihf.cfg"
        self.dataConfigPath = "cfg/coco.data"
        self.weightsPath_hf = "weights/latest_h_f.pt"
        self.confThres = 0.5
        self.nmsThres = 0.45
        self.dataConfig = parse_data_config(self.dataConfigPath)
        self.classes = load_classes(self.dataConfig['names'])

        """
        indentification
        """
        self.weightsPath_c = "./weights/res18_aug_market_cuhk.pth.tar"
        self.suspected_bbx = []
        self.infer_shape = (96, 128)
        # replay buffer: stores the last 10 embedding vectors of the target
        self.target_vector_buffer = np.zeros((10, 512))
        self.target_bbx = np.array([])
        self.bufferSize = 10
        self.bufferPointer = 0
        self.counter = 0      # initial value is 0
        self.way2 = True


    def getCenterModel(self):

        # model = Resnet50FaceModel
        model = Resnet18FaceModel
        model = model(False).to(self.device)
        checkpoint = torch.load(self.weightsPath_c)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        model.eval()

        return model

    def getHFDModel(self):

        model = Darknet(self.yolov3, self.boxSize)
        model.load_state_dict(torch.load(self.weightsPath_hf)['model'])
        model.to(self.device).eval()

        return model

    def getPoseModel(self):

        # note: cascaded_pose_net_dev, self.yoloBase and self.weightsPath come from the
        # original repository and are not defined in the snippet shown here
        model = cascaded_pose_net_dev.PoseModel(cfg_path=self.yoloBase)
        model.load_state_dict(torch.load(self.weightsPath))
        # model = torch.nn.DataParallel(model)
        model.to(self.device).eval()

        return model

    def normalization(self, img, resize=False):
        if resize:
            # print(img.shape)
            h, w = img.shape[:2]
            img = cv2.resize(img, (0,0), fx=self.infer_shape[0]/w, fy=self.infer_shape[1]/h, interpolation=cv2.INTER_CUBIC)

        return img.astype(np.float32) / 255.

    def resizeRequested(self, img, height=96, width=96):

        height_, width_ = img.shape[:2]
        return cv2.resize(img, (0,0), fx=width/width_, fy=height/height_, interpolation=cv2.INTER_CUBIC)

    def iou_fillter(self):
        """Compute IoU between detect box and gt boxes

            Parameters:
            ----------
            box: numpy array , shape (4, ): x1, y1, x2, y2
                input box
            boxes: numpy array, shape (n, 4): x1, y1, x2, y2
                input ground truth boxes
        """
        # box = (x1, y1, x2, y2)
        box = self.target_bbx[:]
        # print(box)
        boxes = np.array(self.suspected_bbx)
        if len(boxes) == 0 or len(box) == 0:
            return
        # print(boxes)
        box_area = (box[2] - box[0] + 1) * (box[3] - box[1] + 1)
        area = (boxes[:, 2] - boxes[:, 0] + 1) * (boxes[:, 3] - boxes[:, 1] + 1)

        # obtain the coordinates of the intersection between the target box and the candidate boxes
        xx1 = np.maximum(box[0], boxes[:, 0])
        yy1 = np.maximum(box[1], boxes[:, 1])
        xx2 = np.minimum(box[2], boxes[:, 2])
        yy2 = np.minimum(box[3], boxes[:, 3])

        # compute the width and height of the bounding box
        w = np.maximum(0, xx2 - xx1 + 1)
        h = np.maximum(0, yy2 - yy1 + 1)

        inter = w * h
        ovr = inter / (box_area + area - inter)
        # select ovr > 0.4
        thre_ovr_idex = np.where(ovr > 0.4)
        # update boxes
        u_boxes = boxes[thre_ovr_idex]
        # update ovr
        ovr = ovr[thre_ovr_idex]

        if len(u_boxes) > 3:
            # return the top3 ovr index
            top3_index = np.argsort(ovr)[-3:]
            self.suspected_bbx = u_boxes[top3_index]
        elif len(u_boxes) == 1:
            self.suspected_bbx = u_boxes
        elif len(u_boxes) == 0:
            # The shot has changed abruptly and every IoU is 0, so search for the target
            # among all predicted boxes again; the previous target bbx is no longer
            # meaningful for tracking, so clear it
            self.way2 = True
            self.target_bbx = np.array([])
            self.suspected_bbx = boxes
        # print(self.suspected_bbx)
    
    def indentification(self, img, canvas, model, query):
        # print('using indentification')
        """
        Compared with the original repo, the return values also include location,
        i.e. the corner coordinates of the rectangle that frames the target.
        """

        imgs = []
        ori = img
        location = []
        
        if self.counter != 0:
            self.iou_fillter()
        # print('--------------------3-------------')
        if self.counter == 0:
            query_img = cv2.imread(query)
            #query_img = query
            query_img = self.normalization(query_img, resize=True)

            query_img = torch.from_numpy(query_img.transpose(2, 0, 1)).unsqueeze(0)
            query_img = query_img.to(self.device)
            _, embeddings = model(query_img)
            embeddings = embeddings.cpu().detach().numpy()
            self.target_vector_buffer[self.bufferPointer, :] = embeddings
            self.bufferPointer += 1

            # self.target_bbx = np.append(self.target_bbx, self.suspected_bbx[0])
            self.counter = 1
            # print('---------------------5------------------')
        else:

            for bbx in self.suspected_bbx:
                img = ori[int(bbx[1]):int(bbx[3]), int(bbx[0]):int(bbx[2]), :]
                img = self.normalization(img, resize=True)

                img = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0)
                imgs.append(img)
                # img = self.transform_for_infer(self.infer_shape)(img)
                # imgs.append(img.unsqueeze(0))

            if len(imgs) != 0:
                imgs = torch.cat(imgs, 0)
                imgs = imgs.to(self.device)
                # print(imgs.shape)
                # tic = time.time()
                _, embeddings = model(imgs)
                # toc = time.time()
                # print(toc-tic)
                embeddings = embeddings.cpu().detach().numpy() # (3, 512)

                distance = np.zeros((1, len(self.suspected_bbx))) # (1, 3) 3--bbox 10--vector buffer
                if self.bufferPointer < 19:   # note: bufferPointer never exceeds 10 here, so this branch is always taken
                    for i in range(self.bufferPointer):
                        distance += np.sum((embeddings - np.expand_dims(self.target_vector_buffer[i, :], axis=0))**2, axis=1)
                    distance /= self.bufferPointer
                else:
                    for i in range(self.bufferSize):
                        distance += np.sum((embeddings - np.expand_dims(self.target_vector_buffer[i, :], axis=0))**2, axis=1)
                    distance /= self.bufferSize

                # distance = np.squeeze(distance)
                print(distance)
                # print('-----------------4---------------')

                # Option 1: threshold the distance, e.g. < 0.4
                # index = np.where(distance < 0.4)
                # Option 2: take the bbox with the smallest embedding distance
                index = np.argmin(distance[0])
                if self.way2:
                    if distance[0][index] < 0.6:
                        if self.bufferPointer > 9:
                            self.bufferPointer = 0

                        self.target_vector_buffer[self.bufferPointer, :] = embeddings[index, :]
                        self.bufferPointer += 1

                        x1, y1, x2, y2 = self.suspected_bbx[index]
                        # update the target bbx
                        # print(self.target_bbx)
                        # print(self.suspected_bbx[index])
                        self.target_bbx = self.suspected_bbx[index]
                        label = 'Target %f' % distance[0][index]
                        plot_one_box([x1, y1, x2, y2], canvas, label=label, color=(0, 255, 170))
                        self.way2 = False
                        location = [x1, y1, x2, y2]
                else:
                    # print('-----------------6------------------')
                    if distance[0][index] < 0.4:
                        if self.bufferPointer > 9:
                            self.bufferPointer = 0
                        self.target_vector_buffer[self.bufferPointer, :] = embeddings[index, :]
                        self.bufferPointer += 1

                        x1, y1, x2, y2 = self.suspected_bbx[index]
                        # update the target bbx
                        # print(self.target_bbx)
                        # print(self.suspected_bbx[index])
                        self.target_bbx = self.suspected_bbx[index]
                        label = 'Target %f'%distance[0][index]
                        plot_one_box([x1, y1, x2, y2], canvas, label=label, color=(0, 255, 170))
                        # print('-------------------------7--------------------')
                        location = [x1, y1, x2, y2]

        return canvas, location

    
    def humanFaceDetector(self, img, canvas, model):
        # print('using humanFaceDetector\n')

        ori = img
        img, _, _, _ = resize_square(img, height=self.boxSize, color=(127.5, 127.5, 127.5))
        img = img[:, :, ::-1].transpose(2, 0, 1)
        img = np.ascontiguousarray(img, dtype=np.float32)
        img = self.normalization(img)
        img = torch.from_numpy(img).unsqueeze(0).to(self.device)
        # print('-----------------1------------------')

        img_detections = []
        with torch.no_grad():
            pred = model(img)
            pred = pred[pred[:, :, 4] > self.confThres]

            if len(pred) > 0:
                detections = non_max_suppression(pred.unsqueeze(0), self.confThres, self.nmsThres)
                img_detections.extend(detections)
            else:
                detections = np.array([])
        # print('----------------------2--------------------')

        if len(detections) != 0:

            # The amount of padding that was added
            pad_x = max(ori.shape[0] - ori.shape[1], 0) * (self.boxSize / max(ori.shape))
            pad_y = max(ori.shape[1] - ori.shape[0], 0) * (self.boxSize / max(ori.shape))
            # Image height and width after padding is removed
            unpad_h = self.boxSize - pad_y
            unpad_w = self.boxSize - pad_x

            for x1, y1, x2, y2, conf, cls_conf, cls_pred in detections[0]:

                # Rescale coordinates to original dimensions
                box_h = ((y2 - y1) / unpad_h) * ori.shape[0]
                box_w = ((x2 - x1) / unpad_w) * ori.shape[1]
                y1 = (((y1 - pad_y // 2) / unpad_h) * ori.shape[0]).round().item()
                x1 = (((x1 - pad_x // 2) / unpad_w) * ori.shape[1]).round().item()
                x2 = (x1 + box_w).round().item()
                y2 = (y1 + box_h).round().item()
                x1, y1, x2, y2 = max(x1, 0), max(y1, 0), max(x2, 0), max(y2, 0)

                label = '%s %.2f' % (self.classes[int(cls_pred)], conf)
                color = [(255, 85, 0), (0, 255, 170)]

                if int(cls_pred) == 0:
                    self.suspected_bbx.append([x1, y1, x2, y2])
                    # plot_one_box([x1, y1, x2, y2], canvas, label=label, color=color[int(cls_pred)])
                # else:
                #     plot_one_box([x1, y1, x2, y2], canvas, label=label, color=color[int(cls_pred)])


        return canvas
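
The class above only provides the building blocks; the main loop that ties them together is not included in this post. The following is a minimal sketch of how the pieces might be wired up for a single frame, using only the methods defined above (the image path and the query image are placeholders):

tracker = Tracker()
hfd_model = tracker.getHFDModel()         # YOLO human/face detector
center_model = tracker.getCenterModel()   # re-identification embedding model

frame = cv2.imread('data/test.jpg')
canvas = frame.copy()

tracker.suspected_bbx = []                # reset the candidate boxes for this frame
canvas = tracker.humanFaceDetector(frame, canvas, hfd_model)         # fills suspected_bbx
canvas, location = tracker.indentification(frame, canvas, center_model, 'data/query.jpg')

if location:                              # a target was found in this frame
    target_img = extractRectangleTarget(frame, location)
    _, line_img, keypoints_location, _, _, _ = humanPoseDetector(target_img)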

4.1.2 Other code files used as-is

There are a few more code files; just save them to the appropriate locations according to the paths above. They are not pasted here and will be uploaded later.

4.1.3 Target-extraction code

# -*- coding: utf-8 -*-
"""
Created on Sat May 25 13:07:55 2019

@author: wangw
"""

"""
用于第一步框选出目标之后,提取矩形区域
暂时只提取出矩形区域即可
后期处理视频的话,由于每一帧的目标矩形大小不确定,所以可能需要建一个稍微大的全黑图像,将提取出的图像放在其中
"""

import cv2
import numpy as np

def extractRectangleTarget(img, location):
    """
    The image is stored as a 2-D matrix; an OpenCV point is given as (column, row),
    i.e. (x, y), while NumPy indexing is img[row, column].
    """
    # rectangle position: location = [x1, y1, x2, y2]
    x1, x2, y1, y2 = int(location[0]), int(location[2]), int(location[1]), int(location[3])
    extract_target_img = img[y1:y2, x1:x2]
    
    return extract_target_img
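
The title also promises angle calculation, and the main file imports calculatekeypointsAngle from calculateangle.calculateAngle, but that module is not shown in this post. The following is only a guess at what such a function might look like: a hypothetical helper that computes the angle at a middle joint from three entries of keypoints_location (which, for a single-person crop, holds one [x, y, joint index] entry per COCO joint, or an empty list if the joint was missed).

import math

def joint_angle(p1, p2, p3):
    """Hypothetical helper: angle in degrees at p2, formed by the segments p2-p1 and p2-p3."""
    v1 = (p1[0] - p2[0], p1[1] - p2[1])
    v2 = (p3[0] - p2[0], p3[1] - p2[1])
    n1, n2 = math.hypot(*v1), math.hypot(*v2)
    if n1 == 0 or n2 == 0:                      # a degenerate or missing keypoint
        return None
    cos_a = (v1[0] * v2[0] + v1[1] * v2[1]) / (n1 * n2)
    return math.degrees(math.acos(max(-1.0, min(1.0, cos_a))))

# Example: right-elbow angle from R-Sho (2), R-Elb (3) and R-Wr (4) in the COCO order,
# where keypoints_location comes from humanPoseDetector run on a single-person crop
sho, elb, wr = keypoints_location[2], keypoints_location[3], keypoints_location[4]
if sho and elb and wr:
    angle = joint_angle(sho[:2], elb[:2], wr[:2])
    if angle is not None:
        print('right elbow angle: {:.1f} degrees'.format(angle))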