yolov3-darknet中yolov2参数hier_thresh的意义及作用

在yolov3的python接口中的darknet.py中的detect()函数中包含参数hier_thresh,具体函数如下:

def detect(net, meta, image, thresh=.5, hier_thresh=.5, nms=.45):
    im = load_image(image, 0, 0)
    num = c_int(0)
    pnum = pointer(num)
    predict_image(net, im)
    dets = get_network_boxes(net, im.w, im.h, thresh, hier_thresh, None, 0, pnum)
    num = pnum[0]
    if (nms): do_nms_obj(dets, num, meta.classes, nms);

    res = []
    for j in range(num):
        for i in range(meta.classes):
            if dets[j].prob[i] > 0:
                b = dets[j].bbox
                res.append((meta.names[i], dets[j].prob[i], (b.x, b.y, b.w, b.h)))
    res = sorted(res, key=lambda x: -x[1])
    free_image(im)
    free_detections(dets, num)
    return res

看到使用hier_thresh参数的是get_network_boxes()函数,该函数通过libdarknet.so导入到python环境中,原始的函数定义实现在network.c中,具体为:

/**
 * w, h: 原始待检测图片的宽高
 * thresh: 默认为0.5
 * hier: 默认为0.5
 * map: NULL
 * relative: 1,box座标值是否是相对的(归一化后的)
 * num: 存储object数量
 **/
detection *get_network_boxes(network *net, int w, int h, 
  float thresh, float hier, int *map, int relative, int *num){
  detection *dets = make_network_boxes(net, thresh, num);
  fill_network_boxes(net, w, h, thresh, hier, map, relative, dets);
  return dets;
}

可见参数hier_thresh传给get_network_boxes()中的参数hier,具体首通过network.c中的make_network_boxes()得到多有的检测框,具体函数实现为:

detection *make_network_boxes(network *net, float thresh, int *num){
    layer l = net->layers[net->n - 1];
    int i;
    int nboxes = num_detections(net, thresh); // 获取所有预测box数量,for YOLOv2 case: 13*13*5
    if(num) *num = nboxes;
    detection *dets = calloc(nboxes, sizeof(detection)); // 创建检测类
    for(i = 0; i < nboxes; ++i){
        dets[i].prob = calloc(l.classes, sizeof(float)); // 开辟内存空间,用于保存概率/检测置信度
        if(l.coords > 4){
            dets[i].mask = calloc(l.coords-4, sizeof(float));
        }
    }
    return dets;
}

然后hier和boxes信息dets传递给fill_network_boxes()函数使用,而fill_network_boxes()函数同样位于network.c中,具体函数为:

void fill_network_boxes(network *net, int w, int h, float thresh, float hier, 
  int *map, int relative, detection *dets){
    int j;
    for(j = 0; j < net->n; ++j){
        layer l = net->layers[j];
        if(l.type == YOLO){       // YOLOv2
            int count = get_yolo_detections(l, w, h, net->w, net->h, thresh, map, relative, dets);
            dets += count;
        }
        if(l.type == REGION){      // YOLOv2
            get_region_detections(l, w, h, net->w, net->h, thresh, map, hier, relative, dets);
            dets += l.w*l.h*l.n;
        }
        if(l.type == DETECTION){   // YOLOv1
            get_detection_detections(l, w, h, thresh, dets);
            dets += l.w*l.h*l.n;
        }
    }
}

可见,hier参数最终传递至get_region_detections()中,这对应yolov2中的region检测层,具体函数在region_layer.c中实现,函数具体为:

void get_region_detections(layer l, int w, int h, int netw, int neth, float thresh, int *map, float tree_thresh, int relative, detection *dets)
{
    int i,j,n,z;
    float *predictions = l.output;
    if (l.batch == 2) {
        float *flip = l.output + l.outputs;
        for (j = 0; j < l.h; ++j) {
            for (i = 0; i < l.w/2; ++i) {
                for (n = 0; n < l.n; ++n) {
                    for(z = 0; z < l.classes + l.coords + 1; ++z){
                        int i1 = z*l.w*l.h*l.n + n*l.w*l.h + j*l.w + i;
                        int i2 = z*l.w*l.h*l.n + n*l.w*l.h + j*l.w + (l.w - i - 1);
                        float swap = flip[i1];
                        flip[i1] = flip[i2];
                        flip[i2] = swap;
                        if(z == 0){
                            flip[i1] = -flip[i1];
                            flip[i2] = -flip[i2];
                        }
                    }
                }
            }
        }
        for(i = 0; i < l.outputs; ++i){
            l.output[i] = (l.output[i] + flip[i])/2.;
        }
    }
    for (i = 0; i < l.w*l.h; ++i){ // 13*13,最后一层的feature map/grid
      int row = i / l.w;  // 行
      int col = i % l.w;  // 列
      for(n = 0; n < l.n; ++n){ // grid cell上的预测box,5个,可看作是输出通道数
        int index = n*l.w*l.h + i;  // 位置下标
        for(j = 0; j < l.classes; ++j){
          // dets:所有预测box的数组
          dets[index].prob[j] = 0;  // 分类概率初始化为0
        }
        int obj_index  = entry_index(l, 0, n*l.w*l.h + i, l.coords); // 预测box的confidence下标
        int box_index  = entry_index(l, 0, n*l.w*l.h + i, 0);  // 预测box 下标
        int mask_index = entry_index(l, 0, n*l.w*l.h + i, 4);  // coords>4时才用到
        // predictions就是输出数组
        float scale = l.background ? 1 : predictions[obj_index]; // 预测box的confidence值
        // 从数据数据中获取当前box的x,y,w,h
        dets[index].bbox = get_region_box(predictions, l.biases, n, box_index, col, row, 
          l.w, l.h, l.w*l.h);
        dets[index].objectness = scale > thresh ? scale : 0;    // 判断是否是object
        if(dets[index].mask){  // coords>4时才用到
          for(j = 0; j < l.coords - 4; ++j){
            dets[index].mask[j] = l.output[mask_index + j*l.w*l.h];//从输出中拷贝掩码值
          }
        }
        // 第一种分类,其概率数据的位置下标
        int class_index = entry_index(l, 0, n*l.w*l.h + i, l.coords + !l.background);
        if(l.softmax_tree){  // cfg/yolo9000专有
          // 根据YOLOv2(二)中的式(5)进行概率链式相乘,得到每个分类的最终预测概率
          hierarchy_predictions(predictions + class_index, l.classes, l.softmax_tree, 0, 
    l.w*l.h);
          if(map){  // 手动提供的分类id的映射
            // 对于分类数量为200的某个数据集,将分类id映射到YOLO9000中,但是源码这里map为NULL
            for(j = 0; j < 200; ++j){
              int class_index = entry_index(l, 0, n*l.w*l.h + i, l.coords + 1 + map[j]);
              float prob = scale*predictions[class_index];
              dets[index].prob[j] = (prob > thresh) ? prob : 0;
            }
          } else {  // 自动获取最大预测概率对应的分类(尽可能的细粒度分类)
            // 获取最大预测概率对应的分类id
            int j =  hierarchy_top_prediction(predictions + class_index, 
                       l.softmax_tree, tree_thresh, l.w*l.h);
            dets[index].prob[j] = (scale > thresh) ? scale : 0;
          }
        } else {   // 非 cfg/yolo9000 网络结构
          if(dets[index].objectness){  // confidence大于阈值,认为有object
            for(j = 0; j < l.classes; ++j){
              // 当前预测box的第j种分类位置下标
              int class_index = entry_index(l, 0, n*l.w*l.h + i, l.coords + 1 + j);
              // Pr(object)*Pr(class_j|object)
              float prob = scale*predictions[class_index];
              dets[index].prob[j] = (prob > thresh) ? prob : 0;
            }
          }
        }
      }
    }
    // 以上dets中各个预测box的座标是针对network input尺寸的,
    // 然而还需要校正得到以原始图片尺寸为基准的box座标
    correct_region_boxes(dets, l.w*l.h*l.n, w, h, netw, neth, relative);

可见hier传递给了tree_thresh,而最终使用tree_thresh参数的是hierarchy_top_prediction()函数,该函数获取最大预测概率对应的分类id。该函数位于tree.c中,函数具体为:

int hierarchy_top_prediction(float *predictions, tree *hier, float thresh, int stride)
{
    float p = 1;
    int group = 0;
    int i;
    while(1){
        float max = 0;
        int max_i = 0;

        for(i = 0; i < hier->group_size[group]; ++i){
            int index = i + hier->group_offset[group];
            float val = predictions[(i + hier->group_offset[group])*stride];
            if(val > max){ //得到概率最大的预测框
                max_i = index;
                max = val;
            }
        }
        if(p*max > thresh){  //判别得到的最大概率预测框概率是否大于thresh
            p = p*max;
            group = hier->child[max_i];
            if(hier->child[max_i] < 0) return max_i;
        } else if (group == 0){
            return max_i;
        } else {
            return hier->parent[hier->group_offset[group]];
        }
    }
    return 0;
}

可见tree_thresh具体传递给thresh参数,其具体作用是判别最大概率预测框是否大于thresh。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章