訓練流程圖
最終會創建一個runner,然後調用runner.run時,實際會根據workflow中是train還是val,調用runner.py下的train和val函數。
batch_processor
def batch_processor(model, data, train_mode):
    """Run one forward pass and package the parsed losses for the runner.

    Note: ``train_mode`` is accepted to satisfy the runner's calling
    convention but is not actually used here.
    """
    raw_losses = model(**data)
    total_loss, log_vars = parse_losses(raw_losses)
    return dict(
        loss=total_loss,
        log_vars=log_vars,
        num_samples=len(data['img'].data))
mmcv/runner/runner.py
train
def train(self, data_loader, **kwargs):
    """Run one training epoch over ``data_loader``.

    Hooks fire in this exact order: before_train_epoch, then per batch
    (before_train_iter / after_train_iter), then after_train_epoch.
    """
    self.model.train()
    self.mode = 'train'
    self.data_loader = data_loader
    # Total iteration budget, recomputed each epoch from the loader length.
    self._max_iters = self._max_epochs * len(data_loader)
    self.call_hook('before_train_epoch')
    for i, data_batch in enumerate(data_loader):
        self._inner_iter = i
        self.call_hook('before_train_iter')
        # batch_processor does the forward pass and loss parsing; the
        # optimizer step itself happens inside the after_train_iter hooks.
        outputs = self.batch_processor(
            self.model, data_batch, train_mode=True, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('batch_processor() must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'],
                                   outputs['num_samples'])
        self.outputs = outputs
        self.call_hook('after_train_iter')
        self._iter += 1
    self.call_hook('after_train_epoch')
    self._epoch += 1
val
def val(self, data_loader, **kwargs):
    """Run one validation epoch over ``data_loader``.

    Mirrors ``train`` but puts the model in eval mode, wraps the forward
    pass in ``torch.no_grad()``, and deliberately does NOT advance
    ``self._iter`` or ``self._epoch`` — those counters track training only.
    """
    self.model.eval()
    self.mode = 'val'
    self.data_loader = data_loader
    self.call_hook('before_val_epoch')
    for i, data_batch in enumerate(data_loader):
        self._inner_iter = i
        self.call_hook('before_val_iter')
        # No gradients needed during validation; saves memory and compute.
        with torch.no_grad():
            outputs = self.batch_processor(
                self.model, data_batch, train_mode=False, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('batch_processor() must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'],
                                   outputs['num_samples'])
        self.outputs = outputs
        self.call_hook('after_val_iter')
    self.call_hook('after_val_epoch')
validate目前只在_dist_train中有用到
訓練時,實際調用:losses = model(**data),驗證時,實際調用hook,運行:
with torch.no_grad():
result = runner.model(
return_loss=False, rescale=True, **data_gpu)
其中,TwoStageDetector和SingleStageDetector都繼承了BaseDetector,在BaseDetector中,forward函數定義如下:
@auto_fp16(apply_to=('img', ))
def forward(self, img, img_meta, return_loss=True, **kwargs):
    """Dispatch to the train or test forward path.

    ``return_loss=True`` (the default, used during training) routes to
    ``forward_train`` which returns a dict of losses; ``return_loss=False``
    routes to ``forward_test`` which returns detection results.
    """
    if return_loss:
        return self.forward_train(img, img_meta, **kwargs)
    else:
        return self.forward_test(img, img_meta, **kwargs)
對於forward_test,其代碼如下:
def forward_test(self, imgs, img_metas, **kwargs):
    """Validate test inputs and dispatch to simple_test or aug_test.

    ``imgs`` and ``img_metas`` are parallel lists, one entry per
    test-time augmentation; a single entry means no TTA.
    """
    for name, var in (('imgs', imgs), ('img_metas', img_metas)):
        if not isinstance(var, list):
            raise TypeError('{} must be a list, but got {}'.format(
                name, type(var)))
    num_augs = len(imgs)
    if num_augs != len(img_metas):
        raise ValueError(
            'num of augmentations ({}) != num of image meta ({})'.format(
                len(imgs), len(img_metas)))
    # TODO: remove the restriction of imgs_per_gpu == 1 when prepared
    imgs_per_gpu = imgs[0].size(0)
    assert imgs_per_gpu == 1
    if num_augs == 1:
        return self.simple_test(imgs[0], img_metas[0], **kwargs)
    return self.aug_test(imgs, img_metas, **kwargs)
由上可以看出,子類需要寫simple_test和aug_test函數。
對於一個檢測模型(一階或者二階),在其class中,需要重寫以下函數:
- forward_train
- simple_test
- aug_test # 非必須
下面以retinanet舉個例子,在retinanet的config文件中,model的type是RetinaNet,在mmdet/models/detectors/retinanet.py中,定義了RetinaNet,它的父類是SingleStageDetector,定義在mmdet/models/detectors/single_stage.py中,三個重要函數的代碼如下:
def forward_train(self,
                  img,
                  img_metas,
                  gt_bboxes,
                  gt_labels,
                  gt_bboxes_ignore=None):
    """Compute the detection losses for one training batch.

    Extracts backbone/neck features, runs the bbox head, and delegates
    loss computation to ``self.bbox_head.loss``. Returns the loss dict.
    """
    feats = self.extract_feat(img)
    head_outs = self.bbox_head(feats)
    loss_args = head_outs + (gt_bboxes, gt_labels, img_metas, self.train_cfg)
    return self.bbox_head.loss(
        *loss_args, gt_bboxes_ignore=gt_bboxes_ignore)
def simple_test(self, img, img_meta, rescale=False):
    """Run detection on a single image (no test-time augmentation).

    Returns the per-class bbox results for the one image in the batch
    (imgs_per_gpu == 1 is enforced by the caller).
    """
    feat = self.extract_feat(img)
    head_outs = self.bbox_head(feat)
    bbox_list = self.bbox_head.get_bboxes(
        *(head_outs + (img_meta, self.test_cfg, rescale)))
    results = []
    for det_bboxes, det_labels in bbox_list:
        results.append(
            bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes))
    return results[0]
def aug_test(self, imgs, img_metas, rescale=False):
    """Test-time augmentation is not implemented for this detector."""
    raise NotImplementedError
由上可知,計算loss的函數是在head中定義的,RetinaHead定義在mmdet/models/anchor_heads/retina_head.py中,RetinaHead三個關鍵函數的代碼如下:
def _init_layers(self):
    """Build the classification/regression conv towers and the two output
    convs of the RetinaNet head.

    NOTE: the cls/reg convs are created interleaved per stacked layer so
    that module-creation (and hence any RNG consumption at construction)
    matches the reference implementation exactly.
    """
    self.relu = nn.ReLU(inplace=True)
    self.cls_convs = nn.ModuleList()
    self.reg_convs = nn.ModuleList()
    for layer_idx in range(self.stacked_convs):
        # First layer adapts from the input channels; the rest stay at
        # feat_channels.
        in_ch = self.in_channels if layer_idx == 0 else self.feat_channels
        for tower in (self.cls_convs, self.reg_convs):
            tower.append(
                ConvModule(
                    in_ch,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
    # Output convs: one anchor-wise classification map and one 4-coord
    # regression map per spatial position.
    self.retina_cls = nn.Conv2d(
        self.feat_channels,
        self.num_anchors * self.cls_out_channels,
        3,
        padding=1)
    self.retina_reg = nn.Conv2d(
        self.feat_channels, self.num_anchors * 4, 3, padding=1)
def init_weights(self):
    """Initialize head weights with normal(std=0.01).

    The classification output bias is set via ``bias_init_with_prob(0.01)``
    so the initial foreground probability is ~0.01 (the focal-loss
    initialization trick from the RetinaNet paper).
    """
    for tower in (self.cls_convs, self.reg_convs):
        for module in tower:
            normal_init(module.conv, std=0.01)
    normal_init(self.retina_cls, std=0.01, bias=bias_init_with_prob(0.01))
    normal_init(self.retina_reg, std=0.01)
def forward_single(self, x):
    """Run one feature level through both conv towers.

    Returns ``(cls_score, bbox_pred)`` for this level.
    """
    cls_feat, reg_feat = x, x
    for conv in self.cls_convs:
        cls_feat = conv(cls_feat)
    for conv in self.reg_convs:
        reg_feat = conv(reg_feat)
    return self.retina_cls(cls_feat), self.retina_reg(reg_feat)
其中,_init_layers創建head的結構,init_weights對conv的weight和bias做初始化,forward_single是經過head計算得到的分類和檢測框預測結果。
forward
在具體的方法對應的head定義forward_single,最後由anchor_head.py中的forward函數進行組裝。
from six.moves import map, zip
def multi_apply(func, *args, **kwargs):
    """Apply ``func`` position-wise across parallel sequences and regroup.

    ``args`` are parallel sequences (e.g. per-stride feature maps); ``func``
    is called once per position and must return a tuple. The per-call tuples
    are then transposed, yielding a tuple of lists — one list per output
    field, each ordered by position, e.g.
    ``([lvl1_cls, lvl2_cls, ...], [lvl1_bbox, lvl2_bbox, ...])``.
    Any ``kwargs`` are bound to ``func`` up front.
    """
    bound = partial(func, **kwargs) if kwargs else func
    per_position = [bound(*grouped) for grouped in zip(*args)]
    return tuple(list(column) for column in zip(*per_position))
def forward(self, feats):
    """Run the head over every feature level.

    Args:
        feats: sequence of NCHW feature maps, one per stride level.

    Returns:
        tuple: per-output-field lists (e.g. all cls scores, then all bbox
        preds), each ordered by level — see ``multi_apply``.
    """
    return multi_apply(self.forward_single, feats)
def forward_single(self, x):
    """Forward a single feature level ``x`` (one NCHW tensor).

    The same input feeds two independent conv towers; the classification
    tower ends in ``retina_cls`` and the regression tower in ``retina_reg``.
    Returns ``(cls_score, bbox_pred)``.
    """
    cls_feat = x
    for layer in self.cls_convs:
        cls_feat = layer(cls_feat)
    reg_feat = x
    for layer in self.reg_convs:
        reg_feat = layer(reg_feat)
    cls_score = self.retina_cls(cls_feat)
    bbox_pred = self.retina_reg(reg_feat)
    return cls_score, bbox_pred
loss