YOLO後處理理論分析代碼分析

理論分析

system.png

YOLO從v2版本開始重新啓用anchor box,YOLOv2網絡的網絡輸出爲尺寸爲[b,125,13,13]的tensor,要將這個Tensor變爲最終的輸出結果,還需要以下的處理:

  • 解碼:從Tensor中解析出所有框的位置信息和類別信息
  • NMS:篩選最能表現物品的識別框

解碼過程

解碼之前,需要明確的是每個候選框需要5+class_num個數據,分別是相對位置x,y,相對寬度w,h,置信度c和class_num個分類結果,YOLOv2-voc中class_num=20,即每個格點對應5個候選框,每個候選框有5+20=25個參數,這就是爲什麼輸出Tensor的最後一維爲5*(20+5)=125。

tensor.png

上圖爲一個框所需要的所有數據構成,假設這個框是位於格點X,Y的,對應的anchor box大小爲W,H,位置相關參數的處理方法如下所示,其中,

分別是輸出Tensor在長寬上的值,這裏

分別爲原圖片的長和寬:

置信度和類別信息處理方法如下所示:

當格點置信度大於某個閾值時,認爲該格點有物體,物體類別爲class_id對應的類別

NMS

NMS爲非最大值抑制,用在YOLO系統中的含義指從多個候選框標記同一個物品時,從中選擇最合適的候選框。其基本思維很簡單:使用置信度最高的候選框標記一個物體,若其他候選框與該候選框的IOU超過一個閾值,則認爲其他候選框與該候選框標記的是同一個物體,丟棄其他候選框。

具體實現時,可以將所有候選框進行排序,置信度高的在前,置信度低的在後。從置信度高的候選框開始遍歷所有候選框,對於某一個候選框,將之後所有的候選框與其計算IOU,若IOU高於一個閾值,則丟棄置信度低的候選框。算法流程圖如下所示:

nms.png

代碼分析

這裏選擇的是marvis開源的基於Pytorch的YOLOv2代碼,其優勢在於所有的部分均使用Python實現,沒有使用Cython,無需編譯即可使用,且依賴較少,文件管理比較扁平。

解碼部分

解碼部分在utils.py文件中,由get_region_boxes函數實現。首先是準備部分,這裏首先獲取了輸出的相關信息,yolo-voc網絡下有b爲batch,預測模式下一般爲1,h=w=13。隨後reshape了輸出,其維度變爲(25,13*13*5),改變維度的目的是方便後面處理的索引。

def get_region_boxes(output, conf_thresh, num_classes, anchors, num_anchors, only_objectness=1, validation=False):    
    anchor_step = len(anchors) / num_anchors
    if output.dim() == 3:
        output = output.unsqueeze(0)
    batch = output.size(0)
    assert(output.size(1) == (5 + num_classes) * num_anchors)
    h = output.size(2)
    w = output.size(3)
    output = output.view(batch * num_anchors, 5 + num_classes, h * w).transpose(
        0, 1).contiguous().view(5 + num_classes, batch * num_anchors * h * w)
    all_boxes = []

隨後是處理x,y的部分,xs和ys就是處理後的候選框中心點相對座標,grid_x和grid_y與output[0]shape相同,分別表示對應output位置的候選框所屬的格點座標X與Y,這裏的xs和ys實現了上述公式中的

    grid_x = torch.linspace(0, w - 1, w).repeat(h, 1).repeat(batch *
                                                             num_anchors, 1, 1).view(batch * num_anchors * h * w).cuda()
    grid_y = torch.linspace(0, h - 1, h).repeat(w, 1).t().repeat(
        batch * num_anchors, 1, 1).view(batch * num_anchors * h * w).cuda()
    print("outputs shape", output.shape)
    xs = torch.sigmoid(output[0]) + grid_x
    ys = torch.sigmoid(output[1]) + grid_y

之後爲處理w,h的部分,與處理x,y的部分類似,最終ws和hs爲修正後的物品尺寸信息,實現了

。其中W和H分別爲當前anchor box的建議尺寸。

    anchor_w = torch.Tensor(anchors).view(
        num_anchors, anchor_step).index_select(1, torch.LongTensor([0]))
    anchor_h = torch.Tensor(anchors).view(
        num_anchors, anchor_step).index_select(1, torch.LongTensor([1]))
    anchor_w = anchor_w.repeat(batch, 1).repeat(
        1, 1, h * w).view(batch * num_anchors * h * w).cuda()
    anchor_h = anchor_h.repeat(batch, 1).repeat(
        1, 1, h * w).view(batch * num_anchors * h * w).cuda()
    ws = torch.exp(output[2]) * anchor_w
    hs = torch.exp(output[3]) * anchor_h

接下來是獲取置信度的部分和類別部分,獲取該anchor box的置信度爲det_confs=sigmoid(c)。隨後處理類別信息,先對類別信息對應的數據做softmax操作,隨後獲取其最大值cls_max_confs和最大值所在的位置cls_max_ids,其中位置cls_max_ids對應每個anchor box框住的“物品”的類別。

    det_confs = torch.sigmoid(output[4])
    cls_confs = torch.nn.Softmax()(
        Variable(output[5:5 + num_classes].transpose(0, 1))).data
    cls_max_confs, cls_max_ids = torch.max(cls_confs, 1)
    cls_max_confs = cls_max_confs.view(-1)
    cls_max_ids = cls_max_ids.view(-1)

隨後是一些其他的處理過程,例如獲取格點數量sz_hw,anchor box的數量sz_hwa等,函數convert2cpu是在CPU上覆制一個該數據,注意這裏是拷貝,並不是將數據從GPU轉移到CPU上。

    sz_hw = h * w
    sz_hwa = sz_hw * num_anchors
    det_confs = convert2cpu(det_confs)
    cls_max_confs = convert2cpu(cls_max_confs)
    cls_max_ids = convert2cpu_long(cls_max_ids)
    xs = convert2cpu(xs)
    ys = convert2cpu(ys)
    ws = convert2cpu(ws)
    hs = convert2cpu(hs)
    if validation:
        cls_confs = convert2cpu(cls_confs.view(-1, num_classes))

隨後是一個解碼的大循環,分析見下面的註釋

    for b in range(batch):
        boxes = []
        # boxes爲容納所有候選框的list
        for cy in range(h):
            for cx in range(w):
                for i in range(num_anchors):
                    # 遍歷每一個anchor box,這裏訪問位於格點cx,cy的第i個anchor box
                    ind = b * sz_hwa + i * sz_hw + cy * w + cx
                    # 獲取該anchor box在det_conf中對應的index
                    det_conf = det_confs[ind]
                    if only_objectness:
                        conf = det_confs[ind]
                    else:
                        conf = det_confs[ind] * cls_max_confs[ind]
                    # 處理置信度

                    if conf > conf_thresh:
                        # 若置信度大於閾值,則認爲該anchor box有效
                        bcx = xs[ind]
                        bcy = ys[ind]
                        bw = ws[ind]
                        bh = hs[ind]
                        cls_max_conf = cls_max_confs[ind]
                        cls_max_id = cls_max_ids[ind]
                        # 獲取所有相關信息,包括長,寬,位置,置信度和類別
                        box = [bcx / w, bcy / h, bw / w, bh / h,
                               det_conf, cls_max_conf, cls_max_id]
                        # 處理數據,其中位置信息x,y,尺寸信息w,h均歸一化,使其與輸入圖片尺寸解耦
                        if (not only_objectness) and validation:
                            for c in range(num_classes):
                                tmp_conf = cls_confs[ind][c]
                                if c != cls_max_id and det_confs[ind] * tmp_conf > conf_thresh:
                                    box.append(tmp_conf)
                                    box.append(c)
                        boxes.append(box)
                        # 將處理好的anchor box信息保存在boxes中
        all_boxes.append(boxes)
        return all_boxes

NMS部分

NMS也在utils.py中,函數名爲nms。該函數中,首先實現對所有候選框的排序。這裏使用det_confs獲取了置信度從大到小的anchor box的座標位置sortIds

def nms(boxes, nms_thresh):
    if len(boxes) == 0:
        return boxes

    det_confs = torch.zeros(len(boxes))
    for i in range(len(boxes)):
        det_confs[i] = 1 - boxes[i][4]

    _, sortIds = torch.sort(det_confs)

隨後實現候選框的篩選,從高置信度的候選框開始遍歷,對於每個候選框boxes[sortIds[i]],遍歷所有置信度低於該候選框且置信度不爲0(置信度爲0表示該候選框被拋棄)的候選框,若低置信度候選框與高置信度候選框的IOU大於閾值,則拋棄低置信度候選框。

    out_boxes = []
    for i in range(len(boxes)):
        # 按置信度從高到低遍歷
        box_i = boxes[sortIds[i]]
        if box_i[4] > 0:
            # 置信度大於0表示該候選框沒有在之前的篩選中被拋棄
            out_boxes.append(box_i)
            for j in range(i + 1, len(boxes)):
                # 遍歷所有置信度低於該候選框的候選框
                box_j = boxes[sortIds[j]]
                if bbox_iou(box_i, box_j, x1y1x2y2=False) > nms_thresh:
                    # 若置信度低的候選框與該候選框IOU大於一定值,拋棄低置信度候選框
                    box_j[4] = 0
    return out_boxes
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章