Faster R-CNN 数据组织代码解析

最近想花点时间对Faster R-CNN等经典的目标检测代码进行注释和学习,同时留下学习笔记,和大家一同努力进步,于是有了这篇博文。
Faster R-CNN是R-CNN系列的第三篇文章,其训练分为两部分:先训练RPN网络,再训练Fast R-CNN,两部分交替进行,最终得到满意的目标检测结果。本文主要解析Fast R-CNN部分的数据组织代码(fast_rcnn_prepare_image_roidb函数),使用的是MATLAB源码。数据组织部分的代码是把这套代码改为己用的第一步,因此很有必要学习。
代码引用自:https://github.com/ShaoqingRen/faster_rcnn

function [image_roidb, bbox_means, bbox_stds] = fast_rcnn_prepare_image_roidb(conf, imdbs, roidbs, bbox_means, bbox_stds)
% [image_roidb, bbox_means, bbox_stds] = fast_rcnn_prepare_image_roidb(conf, imdbs, roidbs, bbox_means, bbox_stds)
%
% Flattens one or more imdb/roidb pairs into a single per-image struct array
% and attaches normalized bounding-box regression targets to every entry.
%
% Inputs:
%   conf       - configuration struct, forwarded to append_bbox_regression_targets
%   imdbs      - one imdb struct or a cell array of them; the fields read here
%                are image_at, image_ids, sizes and name
%   roidbs     - one roidb struct or a cell array of them (paired with imdbs);
%                the fields read here are rois(z).overlap/.boxes/.class
%   bbox_means - (optional) precomputed target means; [] or omitted to compute
%   bbox_stds  - (optional) precomputed target stds;  [] or omitted to compute
%                Statistics are only reused when BOTH are non-empty; otherwise
%                both are recomputed from the data.
%
% Outputs:
%   image_roidb - column struct array, one element per image
%   bbox_means, bbox_stds - statistics used for target normalization
% --------------------------------------------------------
% Fast R-CNN
% Reimplementation based on Python Fast R-CNN (https://github.com/rbgirshick/fast-rcnn)
% Copyright (c) 2015, Shaoqing Ren
% Licensed under The MIT License [see LICENSE for details]
% --------------------------------------------------------

    % Default each optional argument independently, so that a call which
    % supplies bbox_means but omits bbox_stds does not hit an undefined
    % variable when both are forwarded below.
    if ~exist('bbox_means', 'var')
        bbox_means = [];
    end
    if ~exist('bbox_stds', 'var')
        bbox_stds = [];
    end

    % Accept a single imdb/roidb pair as well as cell arrays of them.
    if ~iscell(imdbs)
        imdbs = {imdbs};
        roidbs = {roidbs};
    end

    imdbs = imdbs(:);
    roidbs = roidbs(:);

    % Re-pack every image of every database into one struct, adding the
    % (initially empty) 'image' and 'bbox_targets' fields.
    image_roidb = ...
        cellfun(@(x, y) ... % x: one imdb, y: the matching roidb
                arrayfun(@(z) ... % z: image index within this imdb
                        struct('image_path', x.image_at(z), 'image_id', x.image_ids{z}, 'im_size', x.sizes(z, :), 'imdb_name', x.name, ...
                        'overlap', y.rois(z).overlap, 'boxes', y.rois(z).boxes, 'class', y.rois(z).class, 'image', [], 'bbox_targets', []), ...
                [1:length(x.image_ids)]', 'UniformOutput', true),...
        imdbs, roidbs, 'UniformOutput', false);

    % Concatenate the per-database struct arrays into one column array.
    image_roidb = cat(1, image_roidb{:});

    % Fill in bbox_targets for every image and normalize them; this also
    % produces the means/stds when they were not supplied.
    [image_roidb, bbox_means, bbox_stds] = append_bbox_regression_targets(conf, image_roidb, bbox_means, bbox_stds);
end

function [image_roidb, means, stds] = append_bbox_regression_targets(conf, image_roidb, means, stds)
    % Attach regression targets to every image entry, drop images that have
    % no usable samples, and normalize the targets class by class.
    % means and stds are (k+1)-by-4; row 1 is the background class.

    % The class count is the column count of the first image's overlap matrix.
    n_cls = size(image_roidb(1).overlap, 2);
    n_img = length(image_roidb);

    % keep(k) is false when compute_targets reports that image k has neither
    % foreground nor background ROIs.
    keep = true(n_img, 1);
    for k = 1:n_img
        [tgt, ok] = compute_targets(conf, image_roidb(k).boxes, image_roidb(k).overlap);
        image_roidb(k).bbox_targets = tgt;
        keep(k) = ok;
    end

    if any(~keep)
        image_roidb = image_roidb(keep);
        n_img = length(image_roidb);
        fprintf('Warning: fast_rcnn_prepare_image_roidb: filter out %d images, which contains zero valid samples\n', sum(~keep));
    end

    have_stats = exist('means', 'var') && ~isempty(means) && exist('stds', 'var') && ~isempty(stds);
    if ~have_stats
        % Accumulate per-class first and second moments of the targets:
        % var(x) = E(x^2) - E(x)^2
        counts = zeros(n_cls, 1) + eps; % eps keeps the later division finite for unseen classes
        lin_acc = zeros(n_cls, 4);
        sq_acc = zeros(n_cls, 4);
        for k = 1:n_img
            tgt = image_roidb(k).bbox_targets;
            for c = 1:n_cls
                rows = (tgt(:, 1) == c); % column 1 holds the class label
                if any(rows)
                    deltas = tgt(rows, 2:end);
                    counts(c) = counts(c) + nnz(rows);
                    lin_acc(c, :) = lin_acc(c, :) + sum(deltas, 1);
                    sq_acc(c, :) = sq_acc(c, :) + sum(deltas .^ 2, 1);
                end
            end
        end

        % Per-class mean and standard deviation.
        means = bsxfun(@rdivide, lin_acc, counts);
        mean_sq = bsxfun(@rdivide, sq_acc, counts);
        stds = (bsxfun(@minus, mean_sq, means .^ 2)) .^ 0.5;

        % Prepend an all-zero row for the background class.
        means = [zeros(1, 4); means];
        stds = [zeros(1, 4); stds];
    end

    % Normalize the targets of every foreground class: (t - mean) ./ std.
    % Row cls+1 is used because row 1 belongs to the background class.
    for k = 1:n_img
        labels = image_roidb(k).bbox_targets(:, 1);
        for c = 1:n_cls
            rows = (labels == c);
            if any(rows)
                centered = bsxfun(@minus, image_roidb(k).bbox_targets(rows, 2:end), means(c + 1, :));
                image_roidb(k).bbox_targets(rows, 2:end) = bsxfun(@rdivide, centered, stds(c + 1, :));
            end
        end
    end
end


function [bbox_targets, is_valid] = compute_targets(conf, rois, overlap)
% Compute the regression targets for the ROIs of one image and report
% whether the image has any usable sample.
%
% Inputs:
%   conf    - struct; the fields read here are bbox_thresh, fg_thresh,
%             bg_thresh_hi and bg_thresh_lo
%   rois    - N-by-4 boxes (converted to single below)
%   overlap - N-by-K overlap of each ROI with each class (may be sparse)
%
% Outputs:
%   bbox_targets - N-by-5 single; column 1 is the class label and columns
%                  2:5 the regression values; rows stay zero for ROIs whose
%                  max overlap is below conf.bbox_thresh
%   is_valid     - false when no ROI qualifies as foreground or background

    overlap = full(overlap);

    [max_overlaps, max_labels] = max(overlap, [], 2);

    % Make sure the ROIs are floats.
    rois = single(rois);

    bbox_targets = zeros(size(rois, 1), 5, 'single');

    % ROIs with overlap exactly 1 are treated as the ground-truth boxes.
    gt_inds = find(max_overlaps == 1);

    if ~isempty(gt_inds)

        % "Example" ROIs: those overlapping enough to be regressed.
        ex_inds = find(max_overlaps >= conf.bbox_thresh);

        % Overlap between each example ROI and each ground-truth ROI
        % (boxoverlap is an external helper; presumably pairwise IoU).
        ex_gt_overlaps = boxoverlap(rois(ex_inds, :), rois(gt_inds, :));

        % Sanity check: the recomputed best overlap must agree with the
        % stored one. The tolerance was previously written as 1^-6, which
        % evaluates to 1 in MATLAB and made this check vacuous. 1e-4 leaves
        % room for single-precision rounding of the rois.
        assert(all(abs(max(ex_gt_overlaps, [], 2) - max_overlaps(ex_inds)) < 1e-4));

        % Assign every example ROI to the ground-truth ROI it overlaps most.
        [~, gt_assignment] = max(ex_gt_overlaps, [], 2);
        gt_rois = rois(gt_inds(gt_assignment), :);
        ex_rois = rois(ex_inds, :);

        % Regression values from example ROI to its ground-truth ROI
        % (fast_rcnn_bbox_transform is an external helper).
        [regression_label] = fast_rcnn_bbox_transform(ex_rois, gt_rois);

        % Record class label and regression values for the example ROIs.
        bbox_targets(ex_inds, :) = [max_labels(ex_inds), regression_label];
    end

    % Foreground: max overlap >= fg_thresh.
    is_fg = max_overlaps >= conf.fg_thresh;

    % Background: max overlap in [bg_thresh_lo, bg_thresh_hi).
    is_bg = max_overlaps < conf.bg_thresh_hi & max_overlaps >= conf.bg_thresh_lo;

    % An image with neither foreground nor background ROIs is flagged invalid.
    is_valid = true;
    if ~any(is_fg | is_bg)
        is_valid = false;
    end
end
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章