【Object Detection】【maskrcnn_benchmark】Implementation and Use of the Registry() Class

This post collects my study notes on Registry(), a very handy utility class in maskrcnn_benchmark.

1. Implementation of Registry()

Registry() is defined in {ROOT_DIR}/maskrcnn_benchmark/utils/registry.py:

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.


def _register_generic(module_dict, module_name, module):
    assert module_name not in module_dict
    module_dict[module_name] = module


class Registry(dict):
    '''
    A helper class for managing registering modules, it extends a dictionary
    and provides a register function.

    Eg. creating a registry:
        some_registry = Registry({"default": default_module})

    There're two ways of registering new modules:
    1): normal way is just calling register function:
        def foo():
            ...
        some_registry.register("foo_module", foo)
    2): used as decorator when declaring the module:
        @some_registry.register("foo_module")
        @some_registry.register("foo_module_nickname")
        def foo():
            ...

    Access of module is just like using a dictionary, eg:
        f = some_registry["foo_module"]
    '''
    def __init__(self, *args, **kwargs):
        super(Registry, self).__init__(*args, **kwargs)

    def register(self, module_name, module=None):
        # used as function call
        if module is not None:
            _register_generic(self, module_name, module)
            return

        # used as decorator
        def register_fn(fn):
            _register_generic(self, module_name, fn)
            return fn

        return register_fn

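Registry() inherits from Python's built-in dict type, so a Registry() instance is, at bottom, just a dictionary. On top of dict it adds a single method, register(self, module_name, module=None), which is simply a way of adding a key-value pair to the dictionary (the helper _register_generic asserts that the key is not already present). It can be used in two ways:

  1. Calling the method directly
    module must not be None.
    module_name becomes the dict key and module becomes the dict value.
  2. Using it as a decorator
    module must be None (i.e. omitted).
    module_name becomes the dict key, and the function or class object that the decorator receives becomes the value.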
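
A minimal sketch of the two calling conventions, using a copy of the Registry code from section 1 (the no-op __init__ is omitted) so it runs without maskrcnn_benchmark installed. The names REG, foo and bar are illustrative only. It also shows the return values: the direct call returns None, while the decorator form returns an inner function that registers its argument and hands it back unchanged.

```python
def _register_generic(module_dict, module_name, module):
    # Refuse to overwrite a key that is already registered
    assert module_name not in module_dict
    module_dict[module_name] = module


class Registry(dict):
    def register(self, module_name, module=None):
        # Direct call: register right away; the method returns None
        if module is not None:
            _register_generic(self, module_name, module)
            return

        # Decorator form: return a function that registers fn and gives it back
        def register_fn(fn):
            _register_generic(self, module_name, fn)
            return fn

        return register_fn


REG = Registry()


def foo():
    return "foo"


# 1) Direct call with an explicit module
result = REG.register("foo_module", foo)
print(result)  # None

# 2) Decorator form: register() without a module returns register_fn
decorator = REG.register("bar_module")


@decorator
def bar():
    return "bar"


print(bar is REG["bar_module"])  # True: the name bar still points at the original function
print(sorted(REG))  # ['bar_module', 'foo_module']

# Pitfall: an explicit None takes the decorator branch, so nothing is stored
REG.register("nothing", None)
print("nothing" in REG)  # False
```

Because the branch test is `module is not None`, a value such as 0 or an empty list can be registered directly, but None cannot: register("key", None) quietly returns the decorator and stores nothing.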

In maskrcnn_benchmark, Registry() is used mainly to manage classes and functions.

2. Using the register method of Registry()

In[2]: from maskrcnn_benchmark.utils.registry import Registry
In[3]: TEST_REGISTRY = Registry()
In[4]: TEST_REGISTRY
Out[4]: {}
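
Since Registry only subclasses dict, every ordinary dict operation still works on it, and it can be created with initial contents, as the docstring's Registry({"default": default_module}) example suggests. The sketch below reuses a copy of the section 1 code; the keys "default" and "upper" are illustrative.

```python
# Copy of the Registry listing from section 1 (no-op __init__ omitted)
def _register_generic(module_dict, module_name, module):
    assert module_name not in module_dict
    module_dict[module_name] = module


class Registry(dict):
    def register(self, module_name, module=None):
        if module is not None:
            _register_generic(self, module_name, module)
            return

        def register_fn(fn):
            _register_generic(self, module_name, fn)
            return fn

        return register_fn


# dict's constructor is inherited, so initial entries can be passed in
REG = Registry({"default": len})
REG.register("upper", str.upper)

print(isinstance(REG, dict))  # True
print(list(REG.keys()))  # ['default', 'upper']
print(REG.get("missing"))  # None: ordinary dict.get still works
print(REG["upper"]("abc"))  # ABC
print(REG["default"]([1, 2, 3]))  # 3
```

The flip side of being a plain dict: item assignment such as REG["upper"] = ... or REG.update(...) bypasses register() and therefore its duplicate-key check.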

2.1 Calling the method directly

In[5]: TEST_REGISTRY.register('1', 1)
In[6]: TEST_REGISTRY
Out[6]: {'1': 1}
In[7]: def func1(flag):
  ...:     if flag == 0:
  ...:         return
  ...:     if flag == 1:
  ...:         print('calling func1')
  ...:         
In[8]: TEST_REGISTRY.register('func1', func1)
In[9]: TEST_REGISTRY['func1']
Out[9]: <function __main__.func1(flag)>
In[10]: TEST_REGISTRY['func1'](0)
In[11]: TEST_REGISTRY['func1'](1)
calling func1
In[12]: TEST_REGISTRY
Out[12]: {'1': 1, 'func1': <function __main__.func1(flag)>}
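
What the session above does not show is a repeated key. A small sketch, again on a copy of the section 1 code, of calling register('1', ...) a second time: the assert in _register_generic fails and the original entry is kept.

```python
# Copy of the Registry listing from section 1 (no-op __init__ omitted)
def _register_generic(module_dict, module_name, module):
    assert module_name not in module_dict
    module_dict[module_name] = module


class Registry(dict):
    def register(self, module_name, module=None):
        if module is not None:
            _register_generic(self, module_name, module)
            return

        def register_fn(fn):
            _register_generic(self, module_name, fn)
            return fn

        return register_fn


REG = Registry()
REG.register("1", 1)

rejected = False
try:
    # Same key again: the assert in _register_generic fails
    REG.register("1", 2)
except AssertionError:
    rejected = True

print(rejected)  # True
print(REG)  # {'1': 1}
```

Because the check is a bare assert, it carries no message and disappears entirely under python -O; in that mode the second call would silently overwrite the value with 2.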

2.2 Using it as a decorator

In[13]: @TEST_REGISTRY.register('func2')
   ...: def func2(flag):
   ...:     if flag == 0:
   ...:         return
   ...:     if flag == 1:
   ...:         print('calling func2')
   ...:         
In[14]: TEST_REGISTRY['func2']
Out[14]: <function __main__.func2(flag)>
In[15]: TEST_REGISTRY['func2'](0)
In[16]: TEST_REGISTRY['func2'](1)
calling func2
In[17]: @TEST_REGISTRY.register('Class1')
   ...: class Class1(object):
   ...:     def __init__(self):
   ...:         print('calling Class1')
   ...:         
In[18]: TEST_REGISTRY['Class1']
Out[18]: __main__.Class1
In[19]: TEST_REGISTRY['Class1']()
calling Class1
Out[19]: <__main__.Class1 at 0x7f6bfda056d8>
In[20]: instance1 = TEST_REGISTRY['Class1']()
calling Class1
In[21]: type(instance1)
Out[21]: __main__.Class1
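
Registering a class works the same way, and since register_fn returns its argument, the class name stays bound to the class itself. The sketch below (MODELS and Linear are hypothetical names) also shows the typical pattern of looking a class up by string and constructing it with arguments.

```python
# Copy of the Registry listing from section 1 (no-op __init__ omitted)
def _register_generic(module_dict, module_name, module):
    assert module_name not in module_dict
    module_dict[module_name] = module


class Registry(dict):
    def register(self, module_name, module=None):
        if module is not None:
            _register_generic(self, module_name, module)
            return

        def register_fn(fn):
            _register_generic(self, module_name, fn)
            return fn

        return register_fn


MODELS = Registry()


@MODELS.register("Linear")
class Linear(object):
    def __init__(self, in_dim, out_dim):
        self.in_dim = in_dim
        self.out_dim = out_dim


# The decorator returned the class unchanged
print(Linear is MODELS["Linear"])  # True

# Look the class up by its string key and instantiate it
layer = MODELS["Linear"](4, 2)
print(type(layer).__name__, layer.in_dim, layer.out_dim)  # Linear 4 2
```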

Stacked decorators: several keys mapping to the same value

In[23]: @TEST_REGISTRY.register('func3-1')
   ...: @TEST_REGISTRY.register('func3-2')
   ...: @TEST_REGISTRY.register('func3-3')
   ...: def func3():
   ...:     print('calling func3')
   ...:     
In[24]: TEST_REGISTRY['func3-1']()
calling func3
In[25]: TEST_REGISTRY['func3-2']()
calling func3
In[26]: TEST_REGISTRY['func3-3']()
calling func3
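
Stacking only works because each register_fn returns the original function, so the next decorator up receives func3 itself rather than None. Decorators are applied bottom-up, so on Python 3.7+ (where dicts keep insertion order) the keys appear in reverse order of how they are written. A sketch on a copy of the section 1 code:

```python
# Copy of the Registry listing from section 1 (no-op __init__ omitted)
def _register_generic(module_dict, module_name, module):
    assert module_name not in module_dict
    module_dict[module_name] = module


class Registry(dict):
    def register(self, module_name, module=None):
        if module is not None:
            _register_generic(self, module_name, module)
            return

        def register_fn(fn):
            _register_generic(self, module_name, fn)
            return fn

        return register_fn


REG = Registry()


@REG.register("func3-1")
@REG.register("func3-2")
@REG.register("func3-3")
def func3():
    return "calling func3"


# The bottom decorator runs first, so 'func3-3' is inserted first
print(list(REG))  # ['func3-3', 'func3-2', 'func3-1']

# All three keys hold the very same function object
print(REG["func3-1"] is REG["func3-2"] is REG["func3-3"] is func3)  # True
```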

3. How Registry() is used in maskrcnn_benchmark

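In maskrcnn_benchmark, Registry() works mainly alongside the yacs configuration system, helping to manage the model's components.
The registry instances are first created in {ROOT_DIR}/maskrcnn_benchmark/modeling/registry.py: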

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from maskrcnn_benchmark.utils.registry import Registry

BACKBONES = Registry()
RPN_HEADS = Registry()
ROI_BOX_FEATURE_EXTRACTORS = Registry()
ROI_BOX_PREDICTOR = Registry()
ROI_KEYPOINT_FEATURE_EXTRACTORS = Registry()
ROI_KEYPOINT_PREDICTOR = Registry()
ROI_MASK_FEATURE_EXTRACTORS = Registry()
ROI_MASK_PREDICTOR = Registry()

The method register(self, module_name, module=None) is then used as a decorator in the following files:

  1. {ROOT_DIR}/maskrcnn_benchmark/modeling/backbone/backbone.py
  2. {ROOT_DIR}/maskrcnn_benchmark/modeling/rpn/rpn.py
  3. {ROOT_DIR}/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_feature_extractors.py
  4. {ROOT_DIR}/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_predictors.py
  5. {ROOT_DIR}/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/roi_keypoint_feature_extractors.py
  6. {ROOT_DIR}/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/roi_keypoint_predictors.py
  7. {ROOT_DIR}/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_feature_extractors.py
  8. {ROOT_DIR}/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_predictors.py
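
A single-file sketch of this pattern. In the project the registries live in modeling/registry.py and the component files import that module; here both halves are collapsed into one script, and the builder and class bodies are placeholders, not the real implementations. The four stacked backbone keys are an assumption chosen to mirror the output in Out[6] below: applied bottom-up, they would yield exactly the R-101-C5, R-101-C4, R-50-C5, R-50-C4 order seen there. The RPN head part mirrors Out[7], where the key SingleConvRPNHead maps to a class named RPNHead, showing that the key need not match the object's name.

```python
# Copy of the Registry listing from section 1 (no-op __init__ omitted)
def _register_generic(module_dict, module_name, module):
    assert module_name not in module_dict
    module_dict[module_name] = module


class Registry(dict):
    def register(self, module_name, module=None):
        if module is not None:
            _register_generic(self, module_name, module)
            return

        def register_fn(fn):
            _register_generic(self, module_name, fn)
            return fn

        return register_fn


# --- stands in for modeling/registry.py ---
BACKBONES = Registry()
RPN_HEADS = Registry()


# --- stands in for backbone/backbone.py (placeholder body) ---
@BACKBONES.register("R-50-C4")
@BACKBONES.register("R-50-C5")
@BACKBONES.register("R-101-C4")
@BACKBONES.register("R-101-C5")
def build_resnet_backbone(cfg):
    return "ResNet backbone built from {}".format(cfg)


# --- stands in for rpn/rpn.py (placeholder body) ---
@RPN_HEADS.register("SingleConvRPNHead")
class RPNHead(object):
    pass


print(list(BACKBONES))  # ['R-101-C5', 'R-101-C4', 'R-50-C5', 'R-50-C4']
print(RPN_HEADS["SingleConvRPNHead"].__name__)  # RPNHead
```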

Each of these files imports {ROOT_DIR}/maskrcnn_benchmark/modeling/registry.py as a module. Imported on its own, the module's registries are still empty:

In[2]: from maskrcnn_benchmark.modeling import registry
In[3]: type(registry)
Out[3]: module
In[4]: registry.BACKBONES
Out[4]: {}
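
The registries are empty here because registration is a side effect of executing the decorated definitions, which only happens when a component file is imported. The sketch below simulates that with a hypothetical function load_backbone_module() whose body plays the role of the module body of backbone.py.

```python
# Copy of the Registry listing from section 1 (no-op __init__ omitted)
def _register_generic(module_dict, module_name, module):
    assert module_name not in module_dict
    module_dict[module_name] = module


class Registry(dict):
    def register(self, module_name, module=None):
        if module is not None:
            _register_generic(self, module_name, module)
            return

        def register_fn(fn):
            _register_generic(self, module_name, fn)
            return fn

        return register_fn


BACKBONES = Registry()
print(BACKBONES)  # {}  (nothing has registered yet, as in Out[4])


def load_backbone_module():
    # Hypothetical stand-in for importing backbone.py: running the
    # definition executes the decorator, which performs the registration
    @BACKBONES.register("R-50-FPN")
    def build_resnet_fpn_backbone(cfg):
        return "FPN backbone"


load_backbone_module()
print(list(BACKBONES))  # ['R-50-FPN']
```

Python caches imported modules in sys.modules, so a normal second import does not re-run the decorators. Re-executing a module body (for example via importlib.reload), or calling load_backbone_module() twice here, would try to register the same key again and trip the duplicate-key assert.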

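Once the files that define the components have been imported, their decorators have run and the registries are populated: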

In[5]: from maskrcnn_benchmark.modeling.backbone.backbone import registry
In[6]: registry.BACKBONES
Out[6]: 
{'R-101-C5': <function maskrcnn_benchmark.modeling.backbone.backbone.build_resnet_backbone(cfg)>,
 'R-101-C4': <function maskrcnn_benchmark.modeling.backbone.backbone.build_resnet_backbone(cfg)>,
 'R-50-C5': <function maskrcnn_benchmark.modeling.backbone.backbone.build_resnet_backbone(cfg)>,
 'R-50-C4': <function maskrcnn_benchmark.modeling.backbone.backbone.build_resnet_backbone(cfg)>,
 'R-152-FPN': <function maskrcnn_benchmark.modeling.backbone.backbone.build_resnet_fpn_backbone(cfg)>,
 'R-101-FPN': <function maskrcnn_benchmark.modeling.backbone.backbone.build_resnet_fpn_backbone(cfg)>,
 'R-50-FPN': <function maskrcnn_benchmark.modeling.backbone.backbone.build_resnet_fpn_backbone(cfg)>,
 'R-101-FPN-RETINANET': <function maskrcnn_benchmark.modeling.backbone.backbone.build_resnet_fpn_p3p7_backbone(cfg)>,
 'R-50-FPN-RETINANET': <function maskrcnn_benchmark.modeling.backbone.backbone.build_resnet_fpn_p3p7_backbone(cfg)>,
 'FBNet': <function maskrcnn_benchmark.modeling.backbone.fbnet.add_conv_body(cfg, dim_in=3)>}
In[7]: registry.RPN_HEADS
Out[7]: 
{'SingleConvRPNHead': maskrcnn_benchmark.modeling.rpn.rpn.RPNHead,
 'FBNet.rpn_head': <function maskrcnn_benchmark.modeling.backbone.fbnet.add_rpn_head(cfg, in_channels, num_anchors)>}
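
The payoff is the lookup step: a model-building function reads a string from the yacs config and uses it as a registry key, so switching components means editing a YAML value rather than code. The sketch assumes a config field named cfg.MODEL.BACKBONE.CONV_BODY, which as far as I know is what maskrcnn_benchmark's backbone builder reads. The nested SimpleNamespace objects stand in for a yacs CfgNode, and build_backbone is a simplified stand-in rather than the project's exact code.

```python
from types import SimpleNamespace


# Copy of the Registry listing from section 1 (no-op __init__ omitted)
def _register_generic(module_dict, module_name, module):
    assert module_name not in module_dict
    module_dict[module_name] = module


class Registry(dict):
    def register(self, module_name, module=None):
        if module is not None:
            _register_generic(self, module_name, module)
            return

        def register_fn(fn):
            _register_generic(self, module_name, fn)
            return fn

        return register_fn


BACKBONES = Registry()


@BACKBONES.register("R-50-FPN")
def build_resnet_fpn_backbone(cfg):
    return "ResNet-50-FPN"  # placeholder for the real network


def build_backbone(cfg):
    # The config stores a string; the registry maps it to a builder function
    name = cfg.MODEL.BACKBONE.CONV_BODY
    assert name in BACKBONES, "{} is not registered".format(name)
    return BACKBONES[name](cfg)


# SimpleNamespace stands in for a yacs CfgNode loaded from a YAML file
cfg = SimpleNamespace(
    MODEL=SimpleNamespace(BACKBONE=SimpleNamespace(CONV_BODY="R-50-FPN"))
)
print(build_backbone(cfg))  # ResNet-50-FPN
```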