0. Preface
This article is a set of notes from using mmdetection, recording my personal understanding of the code and its design philosophy.
1. train
Training is launched with tools/train.py. Add the following code to run it in debug mode:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
args = ['./configs/cascade_mask_rcnn_r101_fpn_1x.py',
        '--gpus', '1',
        '--work_dir', 'cascade_mask_rcnn_r101_fpn_1x'
        ]
1.1. First, build the model:
model = build_detector(
    cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

def build_detector(cfg, train_cfg=None, test_cfg=None):
    return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
Here, cfg is the config file, and DETECTORS is a global object created in models/registry; it is a Registry instance. The Registry class has _name and _module_dict attributes; initially only _name is set, to a string such as 'detector'. Every detector-related class is preceded by the @DETECTORS.register_module decorator, which registers that class, keyed by its name (the __name__ attribute), in DETECTORS' _module_dict.
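The registration pattern can be sketched roughly as follows (a simplified sketch for illustration, not mmdet's actual implementation; the real Registry has extra checks):

```python
# Simplified sketch of the Registry / register_module pattern.
class Registry:
    def __init__(self, name):
        self._name = name          # e.g. 'detector'
        self._module_dict = {}     # maps class name -> class

    @property
    def name(self):
        return self._name

    def get(self, key):
        return self._module_dict.get(key)

    def register_module(self, cls):
        # used as a decorator: @DETECTORS.register_module
        self._module_dict[cls.__name__] = cls
        return cls

DETECTORS = Registry('detector')

@DETECTORS.register_module
class CascadeRCNN:
    def __init__(self, num_stages=3, train_cfg=None, test_cfg=None):
        self.num_stages = num_stages

# a config's 'type' string now resolves to the registered class:
cls = DETECTORS.get('CascadeRCNN')
model = cls(num_stages=3)
```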
build delegates to build_from_cfg, which first pops the object type obj_type from cfg, uses get to look up the corresponding class in the registry (a Registry object), and uses inspect to verify that obj_type really is a class. It then instantiates obj_type with the remaining args (i.e. the rest of cfg) as keyword arguments.
def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        obj: The constructed object.
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    assert isinstance(default_args, dict) or default_args is None
    args = cfg.copy()
    obj_type = args.pop('type')  # name of the class, e.g. CascadeRCNN
    if mmcv.is_str(obj_type):
        obj_type = registry.get(obj_type)
        if obj_type is None:
            raise KeyError('{} is not in the {} registry'.format(
                obj_type, registry.name))
    elif not inspect.isclass(obj_type):
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(obj_type)))
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_type(**args)  # instantiate and return
1.2. Building train_dataset:
Same routine: build calls build_from_cfg, which instantiates the object as described in the config; the only difference is that here the config is the dataset's cfg.
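The same pattern applied to datasets can be illustrated with a self-contained toy (class, registry, and config values here are all hypothetical stand-ins, not mmdet's real ones):

```python
# Toy illustration of the build pattern for datasets.
class CocoDataset:
    def __init__(self, ann_file, img_prefix):
        self.ann_file = ann_file
        self.img_prefix = img_prefix

DATASETS = {'CocoDataset': CocoDataset}  # plain-dict stand-in for a Registry

def build_from_cfg(cfg, registry):
    args = cfg.copy()
    obj_type = registry[args.pop('type')]  # 'type' selects the class
    return obj_type(**args)               # remaining keys become kwargs

cfg = dict(type='CocoDataset',
           ann_file='annotations/instances_train2017.json',
           img_prefix='train2017/')
train_dataset = build_from_cfg(cfg, DATASETS)
```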
1.3. Training:
train_detector(
    model,
    train_dataset,
    cfg,
    distributed=distributed,
    validate=args.validate,
    logger=logger)
train_detector calls _non_dist_train, which wraps the model for parallel execution and builds the data_loader, optimizer, and runner.
1.3.1 Building the data_loader:
################## the _non_dist_train part
# prepare data loaders
data_loaders = [
    build_dataloader(
        dataset,
        cfg.data.imgs_per_gpu,
        cfg.data.workers_per_gpu,
        cfg.gpus,
        dist=False)
]
################## the build_dataloader function
def build_dataloader(dataset,
                     imgs_per_gpu,
                     workers_per_gpu,
                     num_gpus=1,
                     dist=True,
                     **kwargs):
    shuffle = kwargs.get('shuffle', True)
    if dist:
        rank, world_size = get_dist_info()
        if shuffle:
            sampler = DistributedGroupSampler(dataset, imgs_per_gpu,
                                              world_size, rank)
        else:
            sampler = DistributedSampler(
                dataset, world_size, rank, shuffle=False)
        batch_size = imgs_per_gpu
        num_workers = workers_per_gpu
    else:
        sampler = GroupSampler(dataset, imgs_per_gpu) if shuffle else None
        batch_size = num_gpus * imgs_per_gpu
        num_workers = num_gpus * workers_per_gpu
    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=partial(collate, samples_per_gpu=imgs_per_gpu),
        pin_memory=False,
        **kwargs)
    return data_loader
build_dataloader mainly creates two objects, sampler and collate (the latter via the partial function): the former is a sampler that draws indices, the latter an assembler that groups samples into output batches. After that, PyTorch's own DataLoader does the rest. The sampler accounts for parallel operation. Besides building batches from Sequence and Mapping inputs, collate, more importantly, can batch DataContainer data, a new type created in mmdet that supports multiple kinds of data.
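The role of partial here is just to bake samples_per_gpu into the collate function before handing it to DataLoader. The grouping idea can be sketched with a toy collate (illustration only; the real mmcv collate recurses into Sequence/Mapping values and special-cases DataContainer):

```python
from functools import partial

# Toy collate sketch: chunk a batch into per-GPU groups of
# samples_per_gpu items each. This only shows the grouping idea,
# not the DataContainer handling of the real implementation.
def collate(batch, samples_per_gpu=1):
    return [batch[i:i + samples_per_gpu]
            for i in range(0, len(batch), samples_per_gpu)]

# like collate_fn=partial(collate, samples_per_gpu=imgs_per_gpu):
collate_fn = partial(collate, samples_per_gpu=2)
groups = collate_fn([0, 1, 2, 3])  # two groups of two samples
```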
1.3.3 Building the runner:
runner = Runner(model, batch_processor, optimizer, cfg.work_dir,
                cfg.log_level)
############ definition of batch_processor
def batch_processor(model, data, train_mode):
    losses = model(**data)
    loss, log_vars = parse_losses(losses)
    outputs = dict(
        loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))
    return outputs
batch_processor calls the model to obtain the losses (the model's forward returns losses rather than the network's raw outputs); parse_losses then does some light post-processing on them.
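That post-processing roughly amounts to the following (a sketch of the idea only; the real parse_losses operates on torch tensors and averages lists of per-level losses rather than just summing):

```python
from collections import OrderedDict

# Sketch of parse_losses: flatten list-valued entries, then sum every
# entry whose key contains 'loss' into one total for backpropagation,
# keeping the individual values in log_vars for logging.
def parse_losses(losses):
    log_vars = OrderedDict()
    for name, value in losses.items():
        if isinstance(value, list):
            log_vars[name] = sum(value)
        else:
            log_vars[name] = value
    loss = sum(v for k, v in log_vars.items() if 'loss' in k)
    log_vars['loss'] = loss
    return loss, log_vars

loss, log_vars = parse_losses({'loss_cls': 0.5, 'loss_bbox': [0.1, 0.2]})
# loss is approximately 0.8
```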
Runner's initialization is essentially just the initialization of model, optimizer, work_dir, and so on.
1.3.4 Running the runner:
runner.run(data_loaders, cfg.workflow, cfg.total_epochs)

# parts omitted
def run(self, data_loaders, workflow, max_epochs, **kwargs):
    assert isinstance(data_loaders, list)
    assert mmcv.is_list_of(workflow, tuple)
    assert len(data_loaders) == len(workflow)
    self._max_epochs = max_epochs
    work_dir = self.work_dir if self.work_dir is not None else 'NONE'
    self.logger.info('Start running, host: %s, work_dir: %s',
                     get_host_info(), work_dir)
    self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
    self.call_hook('before_run')
    while self.epoch < max_epochs:
        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if isinstance(mode, str):
                epoch_runner = getattr(self, mode)
            elif callable(mode):  # custom train()
                epoch_runner = mode
            for _ in range(epochs):
                if mode == 'train' and self.epoch >= max_epochs:
                    return
                epoch_runner(data_loaders[i], **kwargs)
    time.sleep(1)  # wait for some hooks like loggers to finish
    self.call_hook('after_run')
workflow describes the training schedule, e.g. [('train', 2), ('val', 1)]. In run, epoch_runner is obtained via getattr and is usually runner.train or runner.val. The former is the ordinary training procedure: it first calls self.model.train() to leave eval mode, then runs a normal training epoch, invoking call_hook at several points along the way.
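The getattr dispatch from mode strings to methods can be seen in isolation with a toy runner (a sketch only, not mmcv's actual Runner):

```python
# Toy sketch of run()'s workflow dispatch: each (mode, epochs) pair
# picks a method by name and calls it that many times on the
# corresponding data loader.
class ToyRunner:
    def __init__(self):
        self.log = []

    def train(self, loader):
        self.log.append(('train', loader))

    def val(self, loader):
        self.log.append(('val', loader))

    def run(self, data_loaders, workflow):
        for i, (mode, epochs) in enumerate(workflow):
            epoch_runner = getattr(self, mode)  # e.g. self.train
            for _ in range(epochs):
                epoch_runner(data_loaders[i])

r = ToyRunner()
r.run(['train_loader', 'val_loader'], [('train', 2), ('val', 1)])
```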
def call_hook(self, fn_name):
    for hook in self._hooks:
        getattr(hook, fn_name)(self)
call_hook applies each hook's fn_name method, one by one, to the runner itself, letting hooks read or modify the runner's state.
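The mechanism can be illustrated with a minimal sketch (the hook class here is hypothetical, just to show the calling convention):

```python
# Minimal sketch of the hook pattern: call_hook invokes the same-named
# method on every registered hook, passing the runner itself so hooks
# can inspect or mutate its state.
class LoggerHook:
    def before_run(self, runner):
        runner.messages.append('before_run seen by LoggerHook')

    def after_run(self, runner):
        runner.messages.append('after_run seen by LoggerHook')

class MiniRunner:
    def __init__(self, hooks):
        self._hooks = hooks
        self.messages = []

    def call_hook(self, fn_name):
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

runner = MiniRunner([LoggerHook()])
runner.call_hook('before_run')
runner.call_hook('after_run')
```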
2. Design
The runner controls the model's training, validation, and testing process.
The dataloader is responsible for loading the data.
Operations such as anchor generation and anchor matching are all hidden inside the model, which is further divided into:
anchor-related heads: anchor_head
backbones: backbones
ROI-related heads: bbox_heads
the detector itself: detectors
loss functions: losses
mask-related heads: mask_heads
modules that extract further features on top of the backbone: necks
plugins such as attention mechanisms: plugins
ROI extractors: roiextractors
not sure what these are for: shared_heads
All the technical details live in these model components, which in turn call functions from core.