yolov3-darknet中yolov2參數hier_thresh的意義及作用

在yolov3的python接口中的darknet.py中的detect()函數中包含參數hier_thresh,具體函數如下:

def detect(net, meta, image, thresh=.5, hier_thresh=.5, nms=.45):
    im = load_image(image, 0, 0)
    num = c_int(0)
    pnum = pointer(num)
    predict_image(net, im)
    dets = get_network_boxes(net, im.w, im.h, thresh, hier_thresh, None, 0, pnum)
    num = pnum[0]
    if (nms): do_nms_obj(dets, num, meta.classes, nms);

    res = []
    for j in range(num):
        for i in range(meta.classes):
            if dets[j].prob[i] > 0:
                b = dets[j].bbox
                res.append((meta.names[i], dets[j].prob[i], (b.x, b.y, b.w, b.h)))
    res = sorted(res, key=lambda x: -x[1])
    free_image(im)
    free_detections(dets, num)
    return res

看到使用hier_thresh參數的是get_network_boxes()函數,該函數通過libdarknet.so導入到python環境中,原始的函數定義實現在network.c中,具體爲:

/**
 * w, h: 原始待檢測圖片的寬高
 * thresh: 默認爲0.5
 * hier: 默認爲0.5
 * map: NULL
 * relative: 1,box座標值是否是相對的(歸一化後的)
 * num: 存儲object數量
 **/
detection *get_network_boxes(network *net, int w, int h, 
  float thresh, float hier, int *map, int relative, int *num){
  detection *dets = make_network_boxes(net, thresh, num);
  fill_network_boxes(net, w, h, thresh, hier, map, relative, dets);
  return dets;
}

可見參數hier_thresh傳給get_network_boxes()中的參數hier,具體首通過network.c中的make_network_boxes()得到多有的檢測框,具體函數實現爲:

detection *make_network_boxes(network *net, float thresh, int *num){
    layer l = net->layers[net->n - 1];
    int i;
    int nboxes = num_detections(net, thresh); // 獲取所有預測box數量,for YOLOv2 case: 13*13*5
    if(num) *num = nboxes;
    detection *dets = calloc(nboxes, sizeof(detection)); // 創建檢測類
    for(i = 0; i < nboxes; ++i){
        dets[i].prob = calloc(l.classes, sizeof(float)); // 開闢內存空間,用於保存概率/檢測置信度
        if(l.coords > 4){
            dets[i].mask = calloc(l.coords-4, sizeof(float));
        }
    }
    return dets;
}

然後hier和boxes信息dets傳遞給fill_network_boxes()函數使用,而fill_network_boxes()函數同樣位於network.c中,具體函數爲:

void fill_network_boxes(network *net, int w, int h, float thresh, float hier, 
  int *map, int relative, detection *dets){
    int j;
    for(j = 0; j < net->n; ++j){
        layer l = net->layers[j];
        if(l.type == YOLO){       // YOLOv2
            int count = get_yolo_detections(l, w, h, net->w, net->h, thresh, map, relative, dets);
            dets += count;
        }
        if(l.type == REGION){      // YOLOv2
            get_region_detections(l, w, h, net->w, net->h, thresh, map, hier, relative, dets);
            dets += l.w*l.h*l.n;
        }
        if(l.type == DETECTION){   // YOLOv1
            get_detection_detections(l, w, h, thresh, dets);
            dets += l.w*l.h*l.n;
        }
    }
}

可見,hier參數最終傳遞至get_region_detections()中,這對應yolov2中的region檢測層,具體函數在region_layer.c中實現,函數具體爲:

void get_region_detections(layer l, int w, int h, int netw, int neth, float thresh, int *map, float tree_thresh, int relative, detection *dets)
{
    int i,j,n,z;
    float *predictions = l.output;
    if (l.batch == 2) {
        float *flip = l.output + l.outputs;
        for (j = 0; j < l.h; ++j) {
            for (i = 0; i < l.w/2; ++i) {
                for (n = 0; n < l.n; ++n) {
                    for(z = 0; z < l.classes + l.coords + 1; ++z){
                        int i1 = z*l.w*l.h*l.n + n*l.w*l.h + j*l.w + i;
                        int i2 = z*l.w*l.h*l.n + n*l.w*l.h + j*l.w + (l.w - i - 1);
                        float swap = flip[i1];
                        flip[i1] = flip[i2];
                        flip[i2] = swap;
                        if(z == 0){
                            flip[i1] = -flip[i1];
                            flip[i2] = -flip[i2];
                        }
                    }
                }
            }
        }
        for(i = 0; i < l.outputs; ++i){
            l.output[i] = (l.output[i] + flip[i])/2.;
        }
    }
    for (i = 0; i < l.w*l.h; ++i){ // 13*13,最後一層的feature map/grid
      int row = i / l.w;  // 行
      int col = i % l.w;  // 列
      for(n = 0; n < l.n; ++n){ // grid cell上的預測box,5個,可看作是輸出通道數
        int index = n*l.w*l.h + i;  // 位置下標
        for(j = 0; j < l.classes; ++j){
          // dets:所有預測box的數組
          dets[index].prob[j] = 0;  // 分類概率初始化爲0
        }
        int obj_index  = entry_index(l, 0, n*l.w*l.h + i, l.coords); // 預測box的confidence下標
        int box_index  = entry_index(l, 0, n*l.w*l.h + i, 0);  // 預測box 下標
        int mask_index = entry_index(l, 0, n*l.w*l.h + i, 4);  // coords>4時纔用到
        // predictions就是輸出數組
        float scale = l.background ? 1 : predictions[obj_index]; // 預測box的confidence值
        // 從數據數據中獲取當前box的x,y,w,h
        dets[index].bbox = get_region_box(predictions, l.biases, n, box_index, col, row, 
          l.w, l.h, l.w*l.h);
        dets[index].objectness = scale > thresh ? scale : 0;    // 判斷是否是object
        if(dets[index].mask){  // coords>4時纔用到
          for(j = 0; j < l.coords - 4; ++j){
            dets[index].mask[j] = l.output[mask_index + j*l.w*l.h];//從輸出中拷貝掩碼值
          }
        }
        // 第一種分類,其概率數據的位置下標
        int class_index = entry_index(l, 0, n*l.w*l.h + i, l.coords + !l.background);
        if(l.softmax_tree){  // cfg/yolo9000專有
          // 根據YOLOv2(二)中的式(5)進行概率鏈式相乘,得到每個分類的最終預測概率
          hierarchy_predictions(predictions + class_index, l.classes, l.softmax_tree, 0, 
    l.w*l.h);
          if(map){  // 手動提供的分類id的映射
            // 對於分類數量爲200的某個數據集,將分類id映射到YOLO9000中,但是源碼這裏map爲NULL
            for(j = 0; j < 200; ++j){
              int class_index = entry_index(l, 0, n*l.w*l.h + i, l.coords + 1 + map[j]);
              float prob = scale*predictions[class_index];
              dets[index].prob[j] = (prob > thresh) ? prob : 0;
            }
          } else {  // 自動獲取最大預測概率對應的分類(儘可能的細粒度分類)
            // 獲取最大預測概率對應的分類id
            int j =  hierarchy_top_prediction(predictions + class_index, 
                       l.softmax_tree, tree_thresh, l.w*l.h);
            dets[index].prob[j] = (scale > thresh) ? scale : 0;
          }
        } else {   // 非 cfg/yolo9000 網絡結構
          if(dets[index].objectness){  // confidence大於閾值,認爲有object
            for(j = 0; j < l.classes; ++j){
              // 當前預測box的第j種分類位置下標
              int class_index = entry_index(l, 0, n*l.w*l.h + i, l.coords + 1 + j);
              // Pr(object)*Pr(class_j|object)
              float prob = scale*predictions[class_index];
              dets[index].prob[j] = (prob > thresh) ? prob : 0;
            }
          }
        }
      }
    }
    // 以上dets中各個預測box的座標是針對network input尺寸的,
    // 然而還需要校正得到以原始圖片尺寸爲基準的box座標
    correct_region_boxes(dets, l.w*l.h*l.n, w, h, netw, neth, relative);

可見hier傳遞給了tree_thresh,而最終使用tree_thresh參數的是hierarchy_top_prediction()函數,該函數獲取最大預測概率對應的分類id。該函數位於tree.c中,函數具體爲:

int hierarchy_top_prediction(float *predictions, tree *hier, float thresh, int stride)
{
    float p = 1;
    int group = 0;
    int i;
    while(1){
        float max = 0;
        int max_i = 0;

        for(i = 0; i < hier->group_size[group]; ++i){
            int index = i + hier->group_offset[group];
            float val = predictions[(i + hier->group_offset[group])*stride];
            if(val > max){ //得到概率最大的預測框
                max_i = index;
                max = val;
            }
        }
        if(p*max > thresh){  //判別得到的最大概率預測框概率是否大於thresh
            p = p*max;
            group = hier->child[max_i];
            if(hier->child[max_i] < 0) return max_i;
        } else if (group == 0){
            return max_i;
        } else {
            return hier->parent[hier->group_offset[group]];
        }
    }
    return 0;
}

可見tree_thresh具體傳遞給thresh參數,其具體作用是判別最大概率預測框是否大於thresh。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章