mmdetection源碼筆記(四):訓練模型之train_detector()的解讀

引言

之前在寫mmdetection源碼的解讀過程時,覺得train_detector()這部分很重要,對於理解整個的訓練過程應該時起着非常大的理解作用。
然後最近研究工作一直在看和修改mmdetection的其他模塊的代碼這一塊。感覺train_detector()這塊內容其實也不是特別重要來着,可能就是一個加強理解的過程。這次還是花了點時間,大致的看了一下,順便加上自己的一些理解,解釋了一下整個過程,如果有錯的話,希望各路大佬指出,互相學習哈。

train_detector()

下面的代碼出現在tools/train.py中,也是main函數的結尾,也就是說,我們訓練的時候,到這就是真正的開始訓練了。

 train_detector(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=args.validate,
        logger=logger)

那到底怎麼訓練的呢?
下面代碼是train_detector()函數的定義,在mmdet/api/train.py文件中

def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   logger=None):
    if logger is None:
        logger = get_root_logger(cfg.log_level)
    # start training
    if distributed:
        _dist_train(model, dataset, cfg, validate=validate)
    else:
        _non_dist_train(model, dataset, cfg, validate=validate)

上面的開始訓練過程分分佈式訓練和非分佈式訓練兩種方法,我們只說分佈式訓練,同樣下面代碼是_dist_train()的定義,也在mmdet/api/train.py中

def _dist_train(model, dataset, cfg, validate=False):
    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    data_loaders = [
        build_dataloader(
            ds, cfg.data.imgs_per_gpu, cfg.data.workers_per_gpu, dist=True)
        for ds in dataset
    ]
    # put model on gpus
    model = MMDistributedDataParallel(model.cuda())

    # build runner 用來爲pytorch訓練用的類,該類在mmcv/mmcv/runner/runner.py中
    optimizer = build_optimizer(model, cfg.optimizer)
    # Optimizer 是用來更新和計算影響模型訓練和模型輸出的網絡參數,使其逼近或達到最優值,從而最小化(或最大化)損失函數E(x)
    # 這種算法使用各參數的梯度值來最小化或最大化損失函數E(x)。最常用的一階優化算法是梯度下降。
    runner = Runner(model, batch_processor, optimizer, cfg.work_dir,
                    cfg.log_level)

    # fp16 setting   用來提速的
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(**cfg.optimizer_config,
                                             **fp16_cfg)
    else:
        optimizer_config = DistOptimizerHook(**cfg.optimizer_config)

    # register hooks hooks 用來查看中間變量的
    # hook的作用是,當反傳時,除了完成原有的反傳,額外多完成一些任務。你可以定義一箇中間變量的hook,將它的grad值打印出來,當然你也可以定義一個全局列表,將每次的grad值添加到裏面去。
    # 下面的hooks也是一樣的,具體pytorch中hooks的作用,可以參考下方鏈接
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)
    runner.register_hook(DistSamplerSeedHook())
    # register eval hooks
    if validate:
        val_dataset_cfg = cfg.data.val
        eval_cfg = cfg.get('evaluation', {})
        if isinstance(model.module, RPN):
            # TODO: implement recall hooks for other datasets
            runner.register_hook(
                CocoDistEvalRecallHook(val_dataset_cfg, **eval_cfg))
        else:
            dataset_type = DATASETS.get(val_dataset_cfg.type)
            if issubclass(dataset_type, datasets.CocoDataset):
                runner.register_hook(
                    CocoDistEvalmAPHook(val_dataset_cfg, **eval_cfg))
            else:
                runner.register_hook(
                    DistEvalmAPHook(val_dataset_cfg, **eval_cfg))

    if cfg.resume_from: # 從resume_from(checkpoint)重新開始訓練?? 
    # (resume_from的作用我猜的,可以自己細看這部分的代碼)
        runner.resume(cfg.resume_from)
    elif cfg.load_from: # 加載 checkpoint,繼續訓練
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs) # 開始訓練

上面代碼,還出現了一個類runner,這個類的作用呢,就是用來更好的訓練pytorch模型的。
簡單的說,就是用runner這個類來操控安排訓練過程中的各個環節。 這個操控包括,要在module中獲取中間變量啊,或者加載和保存檢查點,或者啓動訓練、啓動測試、或者初始化權重,本身這個函數是不能改變這個網絡模型的各個部分的,也就是說,我們要真正修改backbone、或者FPN啊,或者分類迴歸的具體實現,跟這個類無關。
也就是說,你只要把你定義好的網絡模型結構,加載好的數據集,你要的優化器等,扔給runner,他就會來幫你跑模型。
runner這個類定義在mmcv/mmcv/runner/runner.py中,裏面好多方法,想要了解的可以自己慢慢去看。

所以train_detection()這一部分的作用,其實就是幫我們把之前設計好的網絡結構,數據集等,扔給runner,然後就行了,具體怎麼跑呢,不需要太轉牛角尖,畢竟太黑盒了。

如果以上理解有誤,請指出,互相學習哈!

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章