如果讀不懂,是一些語法不太會
該代碼基本上用的就是類的多級繼承運行訓練代碼,不直觀,
註冊函數實現數據不同讀取
然後就是hook 鉤子函數,添加多個任務事件
該工程主要是 三個reid 最新論文的實現方法,添加了一些訓練結構訓練技巧,集成到一起,
小白reid不建議這個工程,可讀性不是很好, 學習的話,從單個reid 論文方法先動手瞭解參數結構,然後再看
(菜鳥一枚,如有問題,歡迎批評)
代碼結構構建參考
https://github.com/facebookresearch/detectron2
https://zhuanlan.zhihu.com/p/96931265
detectron2(目標檢測框架)無死角玩轉-06:源碼詳解(2)-Trainer繼承關係,Hook
https://blog.csdn.net/weixin_43013761/article/details/104092658
1、super().init(model, data_loader, optimizer)
class DefaultTrainer(SimpleTrainer):
def __init__(self, cfg):
....
.....
##父類,SimpleTrainer,構造函數, def __init__(self, model, data_loader, optimizer):
##這裏子類自己寫了構造函數初始化,繼承父類的構造函數需要 寫 super
super().__init__(model, data_loader, optimizer)
類的繼承, 繼承後,子類寫構造函數__init__ 需要用super 初始化父類構造函數,纔可以繼承父類,不繼承父類構造函數
# default.py 類中, 有數據加載 data_loader = self.build_train_loader(cfg)
# 沒有重新定義 __init__(self) 默認繼承父類類所有構造函數
# 這裏直接跳到 DefaultTrainer類中
class Trainer(DefaultTrainer): #類的繼承,這裏繼承之後又添加了一個成員方法 build_evaluator
@classmethod
def build_evaluator(cls, cfg, num_query, output_folder=None):
if output_folder is None:
output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
return ReidEvaluator(cfg, num_query)
2、register 註冊函數 讀取數據,
代碼中有四個裝飾器,不改變代碼功能的前提下增加函數功能
from …utils.registry import Registry
BACKBONE_REGISTRY = Registry(“BACKBONE”)
BACKBONE_REGISTRY.doc = “”"
通過上述三行註冊裝飾器函數
self._name
DATASET
====================================>
data root path /home/shiyy/nas/all_workspace/ReID/data
================================>
self._name
META_ARCH
================================>
self._name
BACKBONE
================================>
self._name
HEADS
不同數據的註冊函數,得到數據,
自己寫一個照着數據寫法,註冊一下
@ 裝飾器函數的功能
模型結構裝飾器
from ...utils.registry import Registry
BACKBONE_REGISTRY = Registry("BACKBONE")
BACKBONE_REGISTRY.__doc__ = """
def build_backbone(cfg):
"""
Build a backbone from `cfg.MODEL.BACKBONE.NAME`.
Returns:
an instance of :class:`Backbone`
"""
backbone_name = cfg.MODEL.BACKBONE.NAME
#backbone_name = build_resnet_backbone
backbone = BACKBONE_REGISTRY.get(backbone_name)(cfg) #打印出來,=build_resnet_backbone
return backbone
# backbone_name 參數的名字是 有 @ 的那一行,函數的名字,調用到resnet網絡結構中的 函數
from .build import BACKBONE_REGISTRY #resnet.py 有這句導入,讓兩個文件連接在一起
@BACKBONE_REGISTRY.register()
def build_resnet_backbone(cfg):
_BASE_: "../Base-bagtricks.yml"
MODEL:
HEADS:
NUM_CLASSES: 751
'''
再 market1501.py 數據處理函數中
@DATASET_REGISTRY.register() #下面是註冊的數據類
class Market1501(ImageDataset):
通過註冊函數,會調用到這句話
dataset = DATASET_REGISTRY.get(" Market1501")(root=_root, combineall=cfg.DATASETS.COMBINEALL)
開始到 market1501.py 的 Market1501類中,各種類的嵌套,讀取數據
'''
DATASETS: #這裏的名字需要是,數據處理的類名,被註冊的 @DATASET_REGISTRY.register() ,否則找不到
NAMES: (" Market1501",) # from .market1501 import market1501
TESTS: (" Market1501",)
OUTPUT_DIR: "logs/market1501/bagtricks_R50"
register 函數做了什麼
def register(self, obj: object = None) -> Optional[object]:
"""
Register the given object under the the name `obj.__name__`.
Can be used as either a decorator or not. See docstring of this class for usage.
函數功能返回 @類名字, 或者 函數的 import 導入路徑
例如 數據返回 fastreid.data.datasets.merge_market1501.Mergedata_market1501
例如模型結構返回 build_resnet_backbone (resnet 中的def build_resnet_backbone)
print(func_or_class #<class
print(func_or_class.__name__)
打印類 和方法 路徑不同
'fastreid.data.datasets.market1501.Market1501'> #Market1501
<function build_resnet_backbone at 0x7f4bd9e36598> #build_resnet_backbone
"""
if obj is None:
# used as a decorator
def deco(func_or_class: object) -> object:
name = func_or_class.__name__ # pyre-ignore
self._do_register(name, func_or_class)
return func_or_class
return deco
# used as a function call
name = obj.__name__ # pyre-ignore
self._do_register(name, obj)
#下面返回的是 註冊方式得到 函數方法,和類方法,impprt導入路徑
<class 'fastreid.data.datasets.cuhk03.CUHK03'>
777777777777777777777777777777777
<class 'fastreid.data.datasets.dukemtmcreid.DukeMTMC'>
777777777777777777777777777777777
<class 'fastreid.data.datasets.market1501.Market1501'>
777777777777777777777777777777777
<class 'fastreid.data.datasets.merge_market1501.Mergedata_market1501'>
777777777777777777777777777777777
<class 'fastreid.data.datasets.msmt17.MSMT17'>
777777777777777777777777777777777
<class 'fastreid.data.datasets.veri.VeRi'>
777777777777777777777777777777777
<class 'fastreid.data.datasets.vehicleid.VehicleID'>
777777777777777777777777777777777
<class 'fastreid.data.datasets.vehicleid.SmallVehicleID'>
777777777777777777777777777777777
<class 'fastreid.data.datasets.vehicleid.MediumVehicleID'>
777777777777777777777777777777777
<class 'fastreid.data.datasets.vehicleid.LargeVehicleID'>
777777777777777777777777777777777
<class 'fastreid.data.datasets.veriwild.VeRiWild'>
777777777777777777777777777777777
<class 'fastreid.data.datasets.veriwild.SmallVeRiWild'>
777777777777777777777777777777777
<class 'fastreid.data.datasets.veriwild.MediumVeRiWild'>
777777777777777777777777777777777
<class 'fastreid.data.datasets.veriwild.LargeVeRiWild'>
777777777777777777777777777777777
<function build_resnet_backbone at 0x7f0338f5e598>
777777777777777777777777777777777
<function build_osnet_backbone at 0x7f0338f69268>
777777777777777777777777777777777
<function build_resnest_backbone at 0x7f0338f69730>
777777777777777777777777777777777
<function build_resnext_backbone at 0x7f0338f69ae8>
777777777777777777777777777777777
<class 'fastreid.modeling.heads.linear_head.LinearHead'>
777777777777777777777777777777777
<class 'fastreid.modeling.heads.bnneck_head.BNneckHead'>
777777777777777777777777777777777
<class 'fastreid.modeling.heads.reduction_head.ReductionHead'>
777777777777777777777777777777777
<class 'fastreid.modeling.meta_arch.baseline.Baseline'>
777777777777777777777777777777777
<class 'fastreid.modeling.meta_arch.mgn.MGN'>
777777777777777777777777777777777
3、訓練數據,測試數據加載評估 hook 任務事件添加
主要代碼在這裏面,中間的各種任務是類的多級調用
D:\Projects\reid\fast-reid\fastreid\engine\defaults.py
'''
執行順序
四個類,mro
class Trainer
(<class 'fastreid.engine.defaults.DefaultTrainer'>,
<class 'fastreid.engine.train_loop.SimpleTrainer'>,
<class 'fastreid.engine.train_loop.TrainerBase'>,
<class 'object'>)
順序執行super 之前的代碼,原路返回再執行super 之後的代碼
1、super 之前的代碼
self.cfg = cfg
logger = logging.getLogger(__name__)
if not logger.isEnabledFor(logging.INFO): # setup_logger is not called for fastreid
setup_logger()
# Assume these objects must be constructed in this order.
model = self.build_model(cfg)
optimizer = self.build_optimizer(cfg, model)
logger.info('Prepare training set')
data_loader = self.build_train_loader(cfg) #加載數據接口
# For training, wrap with DP. But don't need this for inference.
model = DataParallel(model)
if cfg.MODEL.BACKBONE.NORM == "syncBN":
# Monkey-patching with syncBN
patch_replication_callback(model)
model = model.cuda()
self._hooks = []
2 返回執行super 之後的代碼
#
model.train() #model.eval()對應,訓練模型,不是 (DefaultTrainer 中 def train (): super().train(self.start_iter, self.max_iter)
self.model = model
self.data_loader = data_loader
self._data_loader_iter = iter(data_loader)
self.optimizer = optimizer
self.scheduler = self.build_lr_scheduler(cfg, optimizer)
# Assume no other objects need to be checkpointed.
# We can later make it checkpoint the stateful hooks
self.checkpointer = Checkpointer(
# Assume you want to save checkpoints together with logs/statistics
model,
self.data_loader.dataset,
cfg.OUTPUT_DIR,
optimizer=optimizer,
scheduler=self.scheduler,
)
self.start_iter = 0
if cfg.SOLVER.SWA.ENABLED:
self.max_iter = cfg.SOLVER.MAX_ITER + cfg.SOLVER.SWA.ITER
else:
self.max_iter = cfg.SOLVER.MAX_ITER
self.cfg = cfg
self.register_hooks(self.build_hooks()) #建立多個鉤子函數任務列表,這裏包含了數據測試(測試數據的加載和評估)
print("全部準備就緒準備訓練")
3、 trainer = Trainer(cfg)
trainer.resume_or_load(resume=args.resume)
#上述已經完成了數據的加載,模型的加載,多個函數功能的實現加載
print("這纔是開始訓練代碼")
return trainer.train() #調用訓練代碼
'''