These are my study notes on Registry(), a very useful utility class in maskrcnn_benchmark.
1. Implementation of Registry()
Registry() is defined in {ROOT_DIR}/maskrcnn_benchmark/utils/registry:
```python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.


def _register_generic(module_dict, module_name, module):
    assert module_name not in module_dict
    module_dict[module_name] = module


class Registry(dict):
    '''
    A helper class for managing registering modules, it extends a dictionary
    and provides a register function.

    Eg. creating a registry:
        some_registry = Registry({"default": default_module})

    There're two ways of registering new modules:
    1): normal way is just calling register function:
        def foo():
            ...
        some_registry.register("foo_module", foo)
    2): used as decorator when declaring the module:
        @some_registry.register("foo_module")
        @some_registry.register("foo_module_nickname")
        def foo():
            ...

    Access of module is just like using a dictionary, eg:
        f = some_registry["foo_module"]
    '''
    def __init__(self, *args, **kwargs):
        super(Registry, self).__init__(*args, **kwargs)

    def register(self, module_name, module=None):
        # used as function call
        if module is not None:
            _register_generic(self, module_name, module)
            return

        # used as decorator
        def register_fn(fn):
            _register_generic(self, module_name, fn)
            return fn

        return register_fn
```
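Registry() inherits from Python's built-in dict type, so a Registry() instance is essentially a dictionary. On top of dict it adds one method, register(self, module_name, module=None), which is simply a way of adding a key-value pair to the dictionary. It can be used in two ways:
- Called directly
  module must not be None. module_name is the key in the dict and module is the value.
- Used as a decorator
  module must be None. module_name is the key, and the decorated function or class object becomes the value.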
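Because Registry subclasses dict, every ordinary dict operation still works on it, and the assert in _register_generic rejects a second registration under a key that already exists (assertions are stripped under `python -O`, in which case the old entry would be silently overwritten). A minimal sketch using a condensed copy of the class above; the names `reg`, `double`, `triple` and `duplicate_rejected` are illustrative only:

```python
def _register_generic(module_dict, module_name, module):
    # Refuse to overwrite an existing entry
    assert module_name not in module_dict
    module_dict[module_name] = module


class Registry(dict):
    # Condensed copy of the class above (docstring and redundant __init__ omitted)
    def register(self, module_name, module=None):
        # Function-call form: store the value and return None
        if module is not None:
            _register_generic(self, module_name, module)
            return

        # Decorator form: return a function that stores the decorated object
        def register_fn(fn):
            _register_generic(self, module_name, fn)
            return fn

        return register_fn


reg = Registry({"default": len})


def double(x):
    return 2 * x


reg.register("double", double)  # direct call


@reg.register("triple")  # decorator
def triple(x):
    return 3 * x


# Ordinary dict operations still work
print(isinstance(reg, dict))  # True
print(sorted(reg.keys()))     # ['default', 'double', 'triple']
print(reg["triple"](5))       # 15

# A second registration under an existing key trips the assert
try:
    reg.register("double", triple)
    duplicate_rejected = False
except AssertionError:
    duplicate_rejected = True
print(duplicate_rejected)     # True
```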
In maskrcnn_benchmark, Registry() is mainly used to manage classes and functions.
2. Using the register method of Registry()
```python
In[2]: from maskrcnn_benchmark.utils.registry import Registry
In[3]: TEST_REGISTRY = Registry()
In[4]: TEST_REGISTRY
Out[4]: {}
```
2.1 Calling the method directly
```python
In[5]: TEST_REGISTRY.register('1', 1)
In[6]: TEST_REGISTRY
Out[6]: {'1': 1}
In[7]: def func1(flag):
  ...:     if flag == 0:
  ...:         return
  ...:     if flag == 1:
  ...:         print('calling func1')
  ...:
In[8]: TEST_REGISTRY.register('func1', func1)
In[9]: TEST_REGISTRY['func1']
Out[9]: <function __main__.func1(flag)>
In[10]: TEST_REGISTRY['func1'](0)
In[11]: TEST_REGISTRY['func1'](1)
calling func1
In[12]: TEST_REGISTRY
Out[12]: {'1': 1, 'func1': <function __main__.func1(flag)>}
```
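Unlike the decorator form, the direct call works on objects that already exist, including ones you did not write yourself, so it is convenient for registering entries programmatically, for example in a loop. It returns None, so there is nothing useful to assign. A minimal sketch registering functions from the standard `math` module (the registry name `reg` is illustrative, and the class is a condensed copy of the one above):

```python
import math


def _register_generic(module_dict, module_name, module):
    assert module_name not in module_dict
    module_dict[module_name] = module


class Registry(dict):
    # Condensed copy of the class shown in section 1
    def register(self, module_name, module=None):
        if module is not None:
            _register_generic(self, module_name, module)
            return

        def register_fn(fn):
            _register_generic(self, module_name, fn)
            return fn

        return register_fn


reg = Registry()

# The function-call form accepts any existing object, so it can run in a loop
for name in ("sqrt", "floor", "ceil"):
    result = reg.register(name, getattr(math, name))

print(result)                      # None: the call form returns nothing
print(reg["sqrt"](16.0))           # 4.0
print(reg["floor"] is math.floor)  # True
print(list(reg))                   # ['sqrt', 'floor', 'ceil']
```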
2.2 Using it as a decorator
```python
In[13]: @TEST_REGISTRY.register('func2')
   ...: def func2(flag):
   ...:     if flag == 0:
   ...:         return
   ...:     if flag == 1:
   ...:         print('calling func2')
   ...:
In[14]: TEST_REGISTRY['func2']
Out[14]: <function __main__.func2(flag)>
In[15]: TEST_REGISTRY['func2'](0)
In[16]: TEST_REGISTRY['func2'](1)
calling func2
In[17]: @TEST_REGISTRY.register('Class1')
   ...: class Class1(object):
   ...:     def __init__(self):
   ...:         print('calling Class1')
   ...:
In[18]: TEST_REGISTRY['Class1']
Out[18]: __main__.Class1
In[19]: TEST_REGISTRY['Class1']()
calling Class1
Out[19]: <__main__.Class1 at 0x7f6bfda056d8>
In[20]: instance1 = TEST_REGISTRY['Class1']()
calling Class1
In[21]: type(instance1)
Out[21]: __main__.Class1
```
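Since register_fn returns its argument unchanged, decorating does not wrap or replace the class: the name Class1 in the defining module and the registry entry are one and the same object. A quick check with the condensed class (the `created` attribute exists only for illustration):

```python
def _register_generic(module_dict, module_name, module):
    assert module_name not in module_dict
    module_dict[module_name] = module


class Registry(dict):
    # Condensed copy of the class shown in section 1
    def register(self, module_name, module=None):
        if module is not None:
            _register_generic(self, module_name, module)
            return

        def register_fn(fn):
            _register_generic(self, module_name, fn)
            return fn

        return register_fn


reg = Registry()


@reg.register("Class1")
class Class1(object):
    def __init__(self):
        self.created = True


# register_fn returned the class itself, so both names point to the same object
print(reg["Class1"] is Class1)  # True
obj = reg["Class1"]()
print(isinstance(obj, Class1))  # True
```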
Stacked decorators: several keys mapping to the same value
```python
In[23]: @TEST_REGISTRY.register('func3-1')
   ...: @TEST_REGISTRY.register('func3-2')
   ...: @TEST_REGISTRY.register('func3-3')
   ...: def func3():
   ...:     print('calling func3')
   ...:
In[24]: TEST_REGISTRY['func3-1']()
calling func3
In[25]: TEST_REGISTRY['func3-2']()
calling func3
In[26]: TEST_REGISTRY['func3-3']()
calling func3
```
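Stacked decorators are applied bottom-up, so the innermost register call stores its key first, and because each register_fn returns the function unchanged, every key ends up pointing at the same object. A small sketch confirming both points (dict insertion order is guaranteed from Python 3.7 on; func3 here returns a string instead of printing so the result is easy to check):

```python
def _register_generic(module_dict, module_name, module):
    assert module_name not in module_dict
    module_dict[module_name] = module


class Registry(dict):
    # Condensed copy of the class shown in section 1
    def register(self, module_name, module=None):
        if module is not None:
            _register_generic(self, module_name, module)
            return

        def register_fn(fn):
            _register_generic(self, module_name, fn)
            return fn

        return register_fn


reg = Registry()


@reg.register("func3-1")
@reg.register("func3-2")
@reg.register("func3-3")
def func3():
    return "calling func3"


# Decorators apply bottom-up, so the innermost key is inserted first
print(list(reg))  # ['func3-3', 'func3-2', 'func3-1']
# Every key maps to the very same function object
print(reg["func3-1"] is reg["func3-3"] is func3)  # True
```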
3. How Registry() is used in maskrcnn_benchmark
In maskrcnn_benchmark, Registry() mainly works alongside the yacs configuration system, helping to manage the model's components.
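The point of pairing a registry with a config system is that a plain string in the config is enough to pick a component. The sketch below shows that lookup pattern under stated assumptions: build_backbone and build_resnet_backbone here are simplified, hypothetical stand-ins rather than copies of maskrcnn_benchmark code, the config key MODEL.BACKBONE.CONV_BODY is assumed, and a SimpleNamespace stands in for a yacs CfgNode (both support attribute access):

```python
from types import SimpleNamespace


def _register_generic(module_dict, module_name, module):
    assert module_name not in module_dict
    module_dict[module_name] = module


class Registry(dict):
    # Condensed copy of the class shown in section 1
    def register(self, module_name, module=None):
        if module is not None:
            _register_generic(self, module_name, module)
            return

        def register_fn(fn):
            _register_generic(self, module_name, fn)
            return fn

        return register_fn


BACKBONES = Registry()


# Two config names share one hypothetical builder
@BACKBONES.register("R-50-C4")
@BACKBONES.register("R-101-C4")
def build_resnet_backbone(cfg):
    # A real builder would construct a network; this one only reports its input
    return "resnet backbone for " + cfg.MODEL.BACKBONE.CONV_BODY


def build_backbone(cfg):
    # Select the builder purely from the string stored in the config
    name = cfg.MODEL.BACKBONE.CONV_BODY
    assert name in BACKBONES, "{} is not registered in BACKBONES".format(name)
    return BACKBONES[name](cfg)


cfg = SimpleNamespace(
    MODEL=SimpleNamespace(BACKBONE=SimpleNamespace(CONV_BODY="R-101-C4"))
)
print(build_backbone(cfg))  # resnet backbone for R-101-C4
```

With this pattern, switching components means editing a string in the config file rather than adding another if/else branch over model names.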
First, the registry instances are created in {ROOT_DIR}/maskrcnn_benchmark/modeling/registry.py:
```python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from maskrcnn_benchmark.utils.registry import Registry

BACKBONES = Registry()
RPN_HEADS = Registry()
ROI_BOX_FEATURE_EXTRACTORS = Registry()
ROI_BOX_PREDICTOR = Registry()
ROI_KEYPOINT_FEATURE_EXTRACTORS = Registry()
ROI_KEYPOINT_PREDICTOR = Registry()
ROI_MASK_FEATURE_EXTRACTORS = Registry()
ROI_MASK_PREDICTOR = Registry()
```
In each of the following files, register(self, module_name, module=None) is used as a decorator:
- {ROOT_DIR}/maskrcnn_benchmark/modeling/backbone/backbone.py
- {ROOT_DIR}/maskrcnn_benchmark/modeling/rpn/rpn.py
- {ROOT_DIR}/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_feature_extractors.py
- {ROOT_DIR}/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_predictors.py
- {ROOT_DIR}/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/roi_keypoint_feature_extractors.py
- {ROOT_DIR}/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/roi_keypoint_predictors.py
- {ROOT_DIR}/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_feature_extractors.py
- {ROOT_DIR}/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_predictors.py
Each of these files imports {ROOT_DIR}/maskrcnn_benchmark/modeling/registry.py as a module:
```python
In[2]: from maskrcnn_benchmark.modeling import registry
In[3]: type(registry)
Out[3]: module
In[4]: registry.BACKBONES
Out[4]: {}
```
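The registries in modeling/registry.py are module-level objects, and Python caches every imported module in `sys.modules`, so all component files that import it share the same Registry instances; entries appear only once the files containing the decorators have been executed. That is presumably why BACKBONES is still empty in Out[4] above. The sketch simulates this with an in-memory module; the module name `demo_registry` is hypothetical:

```python
import sys
import types


def _register_generic(module_dict, module_name, module):
    assert module_name not in module_dict
    module_dict[module_name] = module


class Registry(dict):
    # Condensed copy of the class shown in section 1
    def register(self, module_name, module=None):
        if module is not None:
            _register_generic(self, module_name, module)
            return

        def register_fn(fn):
            _register_generic(self, module_name, fn)
            return fn

        return register_fn


# Hypothetical in-memory module playing the role of modeling/registry.py
shared = types.ModuleType("demo_registry")
shared.BACKBONES = Registry()
sys.modules["demo_registry"] = shared

import demo_registry as first  # what a "component file" would see

print(dict(first.BACKBONES))  # {} (nothing registered yet)


# The "component file" registers a builder when its code runs
@first.BACKBONES.register("R-50-C4")
def build_resnet_backbone(cfg):
    return None


import demo_registry as second  # served from the sys.modules cache

print(second.BACKBONES is first.BACKBONES)  # True
print(list(second.BACKBONES))               # ['R-50-C4']
```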
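The following shows the registries after a module that performs the registrations has been imported (the decorators run at import time):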
```python
In[5]: from maskrcnn_benchmark.modeling.backbone.backbone import registry
In[6]: registry.BACKBONES
Out[6]:
{'R-101-C5': <function maskrcnn_benchmark.modeling.backbone.backbone.build_resnet_backbone(cfg)>,
 'R-101-C4': <function maskrcnn_benchmark.modeling.backbone.backbone.build_resnet_backbone(cfg)>,
 'R-50-C5': <function maskrcnn_benchmark.modeling.backbone.backbone.build_resnet_backbone(cfg)>,
 'R-50-C4': <function maskrcnn_benchmark.modeling.backbone.backbone.build_resnet_backbone(cfg)>,
 'R-152-FPN': <function maskrcnn_benchmark.modeling.backbone.backbone.build_resnet_fpn_backbone(cfg)>,
 'R-101-FPN': <function maskrcnn_benchmark.modeling.backbone.backbone.build_resnet_fpn_backbone(cfg)>,
 'R-50-FPN': <function maskrcnn_benchmark.modeling.backbone.backbone.build_resnet_fpn_backbone(cfg)>,
 'R-101-FPN-RETINANET': <function maskrcnn_benchmark.modeling.backbone.backbone.build_resnet_fpn_p3p7_backbone(cfg)>,
 'R-50-FPN-RETINANET': <function maskrcnn_benchmark.modeling.backbone.backbone.build_resnet_fpn_p3p7_backbone(cfg)>,
 'FBNet': <function maskrcnn_benchmark.modeling.backbone.fbnet.add_conv_body(cfg, dim_in=3)>}
In[7]: registry.RPN_HEADS
Out[7]:
{'SingleConvRPNHead': maskrcnn_benchmark.modeling.rpn.rpn.RPNHead,
 'FBNet.rpn_head': <function maskrcnn_benchmark.modeling.backbone.fbnet.add_rpn_head(cfg, in_channels, num_anchors)>}
```
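Out[7] shows that one registry can hold a class (RPNHead) next to a factory function (add_rpn_head). Assuming the consuming code simply calls whatever entry it looks up, the two are interchangeable as long as they accept the same arguments. A sketch of that idea; the RPNHead below is a simplified stand-in rather than the real PyTorch module, and the key `FactoryRPNHead` and function `make_rpn_head` are hypothetical:

```python
def _register_generic(module_dict, module_name, module):
    assert module_name not in module_dict
    module_dict[module_name] = module


class Registry(dict):
    # Condensed copy of the class shown in section 1
    def register(self, module_name, module=None):
        if module is not None:
            _register_generic(self, module_name, module)
            return

        def register_fn(fn):
            _register_generic(self, module_name, fn)
            return fn

        return register_fn


RPN_HEADS = Registry()


# Simplified stand-in for a head class; the registry stores the class itself
@RPN_HEADS.register("SingleConvRPNHead")
class RPNHead(object):
    def __init__(self, cfg, in_channels, num_anchors):
        self.in_channels = in_channels
        self.num_anchors = num_anchors


# Hypothetical factory function with the same call signature as the class
@RPN_HEADS.register("FactoryRPNHead")
def make_rpn_head(cfg, in_channels, num_anchors):
    return RPNHead(cfg, in_channels, num_anchors)


# The caller does not need to know whether an entry is a class or a function
heads = {name: RPN_HEADS[name](None, 256, 3) for name in RPN_HEADS}
for name, head in heads.items():
    print(name, type(head).__name__, head.in_channels, head.num_anchors)
# SingleConvRPNHead RPNHead 256 3
# FactoryRPNHead RPNHead 256 3
```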