【代碼】mmdetection框架

0.前言

這篇文章是使用mmdetection的一些記錄,記錄對於代碼、設計理念的個人理解。

1.train

使用tools.train進行訓練。添加如下代碼來使用debug模式:

    import os
    os.environ["CUDA_VISIBLE_DEVICES"] = "4"
    args = ['./configs/cascade_mask_rcnn_r101_fpn_1x.py',
            '--gpus', '1',
            '--work_dir', 'cascade_mask_rcnn_r101_fpn_1x'
            ]

1.1. 首先是建立模型:

    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
def build_detector(cfg, train_cfg=None, test_cfg=None):
    return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))

其中,cfg是config文件,DETECTORS爲全局對象,在models/registry中創建,是一個Registry對象。Registry類含_name和_module_dict屬性,在一開始只將_name賦予’detector’等字符。在每個與檢測器有關的類之前都有 @DETECTORS.register_module 修飾器,它可以將這個類以及其名字(_name_屬性)在DETECTORS的_module_dict中。
build調用build_from_cfg,首先取出cfg建立的對象類型obj_type,使用get從註冊器(Registry對象)中取出相應的類,使用inspect來判斷取出的obj_type是否是類。之後使用obj_type(類)將args(就是cfg)作爲參數進行實例化。

def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        obj: The constructed object.
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    assert isinstance(default_args, dict) or default_args is None
    args = cfg.copy()
    obj_type = args.pop('type')	# 對象的名字,比如CascadeRCNN
    if mmcv.is_str(obj_type):
        obj_type = registry.get(obj_type)
        if obj_type is None:
            raise KeyError('{} is not in the {} registry'.format(
                obj_type, registry.name))
    elif not inspect.isclass(obj_type):
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(obj_type)))
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_type(**args)   #返回實例化

1.2. 建立train_dataset:

同樣的套路,build調用build_from_cfg,按照cfg中的描述進行實例化,只是cfg是dataset的cfg。

1.3. 訓練:

    train_detector(
        model,
        train_dataset,
        cfg,
        distributed=distributed,
        validate=args.validate,
        logger=logger)

train_detector調用non_dist_train,在這裏將model並行化,建立data_loaderoptimizerrunner

1.3.1 建立data_loader:
##################_non_dist_train部分
    # prepare data loaders
    data_loaders = [
        build_dataloader(
            dataset,
            cfg.data.imgs_per_gpu,
            cfg.data.workers_per_gpu,
            cfg.gpus,
            dist=False)
    ]

##################build_dataloader函數
def build_dataloader(dataset,
                     imgs_per_gpu,
                     workers_per_gpu,
                     num_gpus=1,
                     dist=True,
                     **kwargs):
    shuffle = kwargs.get('shuffle', True)
    if dist:
        rank, world_size = get_dist_info()
        if shuffle:
            sampler = DistributedGroupSampler(dataset, imgs_per_gpu,
                                              world_size, rank)
        else:
            sampler = DistributedSampler(
                dataset, world_size, rank, shuffle=False)
        batch_size = imgs_per_gpu
        num_workers = workers_per_gpu
    else:
        sampler = GroupSampler(dataset, imgs_per_gpu) if shuffle else None
        batch_size = num_gpus * imgs_per_gpu
        num_workers = num_gpus * workers_per_gpu

    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=partial(collate, samples_per_gpu=imgs_per_gpu),
        pin_memory=False,
        **kwargs)

build_dataloader中主要是創建了兩個對象samplercollate(通過偏函數partial來創建),前者是採樣器,採樣出下標,後者是整理器,用於組成batch輸出。之後使用pytorch自帶的DataLoader就行了。sampler考慮了並行操作。collate除了支持對於Sequence,Mapping的batch構建外,更重要的是有對於DataContainer類型數據的batch操作,這是一個mmdet中創建的新類型,支持多種數據類型。

1.3.3 建立runner:
    runner = Runner(model, batch_processor, optimizer, cfg.work_dir,
                    cfg.log_level)
############batch_processor的定義
def batch_processor(model, data, train_mode):
    losses = model(**data)
    loss, log_vars = parse_losses(losses)

    outputs = dict(
        loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))

    return outputs

其中batch_processor調用model來得到loss(model的forward得到的是loss而不是網絡的輸出)。之後對loss進行一些小處理。
runner的初始化基本上就是model, optimizer, work_dir等的初始化。

1.3.4 runner運行:
runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
# 刪除部分
    def run(self, data_loaders, workflow, max_epochs, **kwargs):
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)

        self._max_epochs = max_epochs
        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
        self.call_hook('before_run')

        while self.epoch < max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str): 
                    epoch_runner = getattr(self, mode)
                elif callable(mode):  # custom train()
                    epoch_runner = mode
                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= max_epochs:
                        return
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')

workflow代表的是工作流程E.g, [(‘train’, 2), (‘val’, 1)] ;run中通過getattr獲得epoch_runner,一般就是runner.trainrunner.val,前者就是一般的train過程,首先self.model.train()來避免eval狀態。之後就是一般的train了,裏面有用到多處的call_hook

    def call_hook(self, fn_name):
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

call_hook的作用就是一個一個的hook的fn_name這個函數作用到自身,獲得一些或者改變一些信息吧。

2. 思路

runner控制模型的訓練、驗證和測試過程。
dataloader負責數據的導入。
模型中anchor生成、anchor匹配等操作均隱藏在了model中,model又分爲
與anchor有關的head:anchor_head
主幹:backbones
ROI有關的head:bbox_heads
檢測器本體:detectors
損失函數:losses
與mask有關的head:mask_heads
backbone進一步基礎上的特徵提取module:necks
attention機制等插件:plugins
roi提取器:roiextractors
不知道是啥:shared_heads
所有與技術細節有關的部分都放在了這些model當中。這些model也會調用core中的函數。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章