[Dissecting the Ox] Implementing RetinaNet from Scratch (6): Training and Testing RetinaNet

All code has been uploaded to my GitHub repository: https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training
If you find it useful, please give it a star!
All of the code below has been tested with PyTorch 1.4 and confirmed to work correctly.

In parts (1) through (5) of this series I fully reproduced RetinaNet. The overall idea is to split the detector into three independent parts: the forward network, the loss computation, and the decoding step. Looking closely, the loss and decode parts do share some duplicated code, but I deliberately did not factor it out for reuse. Keeping the three parts highly cohesive and loosely coupled means we can apply the latest object detection techniques to any one of them independently and then assemble an improved detector like building blocks. With all of that in place, we can now train and test RetinaNet.

Training RetinaNet

In the RetinaNet paper (https://arxiv.org/pdf/1708.02002.pdf), the standard training recipe is: SGD with momentum=0.9 and weight_decay=0.0001, batch_size=16, and synchronized BN across GPUs. Training runs for 90000 iterations in total, with an initial learning rate of 0.01 that is divided by 10 at iteration 60000 and again at iteration 80000.
My training procedure differs slightly from the above, but not by much. Multiplying 16 by 90000 and dividing by 118287 (the number of images in COCO2017_train) gives roughly 12.17 epochs, so I train for 12 epochs. For simplicity I use the Adam optimizer, which adapts its learning rates automatically. In my experience Adam converges faster than SGD early on, but the final result is slightly worse (the local optimum it reaches is not as good as SGD's). The gap is small; for RetinaNet the mAP difference is generally no more than 0.5 percentage points.
In the Detectron and Detectron2 frameworks, the standard recipe from the RetinaNet paper described above is called 1x training. Similarly, multiplying the total number of iterations and the learning-rate decay milestones by 2 or 3 gives 2x training and 3x training.
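
For reference, here is a minimal sketch of the paper's 1x schedule implemented with SGD and MultiStepLR. This is not the setup used in train.py below (which uses AdamW with ReduceLROnPlateau); the placeholder model simply stands in for the RetinaNet built in the earlier articles:

import torch
import torch.nn as nn

# placeholder model; in practice this would be the resnet50_retinanet built in parts (1)-(5)
model = nn.Conv2d(3, 9, kernel_size=3)

# 1x schedule from the paper: SGD with momentum=0.9, weight_decay=0.0001,
# 90000 iterations, lr=0.01 divided by 10 at iterations 60000 and 80000
optimizer = torch.optim.SGD(model.parameters(),
                            lr=0.01,
                            momentum=0.9,
                            weight_decay=0.0001)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                 milestones=[60000, 80000],
                                                 gamma=0.1)

for iteration in range(90000):
    # forward pass, loss computation, loss.backward() and optimizer.step() go here
    scheduler.step()  # stepped once per iteration, not once per epoch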

Testing RetinaNet on the COCO dataset

To evaluate RetinaNet on COCO we can directly use the API provided by the COCOeval class in pycocotools.cocoeval. We only need to feed the forward outputs of the RetinaNet class (including the anchors) into the RetinaDecoder class, then scale the decoded bboxes back up to the original image size using scale (the decoded bbox coordinates are relative to the resized image). We then filter out the invalid detections for each image (those with class_index of -1), write the remaining detections to a json file in the required format, and call COCOeval to compute the metrics.

The COCOeval class provides 12 metrics:

self.maxDets = [1, 10, 100]  # i.e. the max_detection_num mentioned in the decoder
stats[0] = _summarize(1)
stats[1] = _summarize(1, iouThr=.5, maxDets=self.params.maxDets[2])
stats[2] = _summarize(1, iouThr=.75, maxDets=self.params.maxDets[2])
stats[3] = _summarize(1, areaRng='small', maxDets=self.params.maxDets[2])
stats[4] = _summarize(1, areaRng='medium', maxDets=self.params.maxDets[2])
stats[5] = _summarize(1, areaRng='large', maxDets=self.params.maxDets[2])
stats[6] = _summarize(0, maxDets=self.params.maxDets[0])
stats[7] = _summarize(0, maxDets=self.params.maxDets[1])
stats[8] = _summarize(0, maxDets=self.params.maxDets[2])
stats[9] = _summarize(0, areaRng='small', maxDets=self.params.maxDets[2])
stats[10] = _summarize(0, areaRng='medium', maxDets=self.params.maxDets[2])
stats[11] = _summarize(0, areaRng='large', maxDets=self.params.maxDets[2])

The meaning of each entry is as follows:

# Unless otherwise stated, the COCO performance reported in object detection papers refers to stats[0]. Whether it is measured on coco2017_val or coco2017_test depends on the paper, but the two usually differ by only 0.2 to 0.5 percentage points.
stats[0] : IoU=0.5:0.95,area=all,maxDets=100,mAP
stats[1] : IoU=0.5,area=all,maxDets=100,mAP
stats[2] : IoU=0.75,area=all,maxDets=100,mAP
stats[3] : IoU=0.5:0.95,area=small,maxDets=100,mAP
stats[4] : IoU=0.5:0.95,area=medium,maxDets=100,mAP
stats[5] : IoU=0.5:0.95,area=large,maxDets=100,mAP
stats[6] : IoU=0.5:0.95,area=all,maxDets=1,mAR
stats[7] : IoU=0.5:0.95,area=all,maxDets=10,mAR
stats[8] : IoU=0.5:0.95,area=all,maxDets=100,mAR
stats[9] : IoU=0.5:0.95,area=small,maxDets=100,mAR
stats[10]:IoU=0.5:0.95,area=medium,maxDets=100,mAR
stats[11]:IoU=0.5:0.95,area=large,maxDets=100,mAR

The evaluation code for the COCO dataset is implemented as follows:

def validate(val_dataset, model, decoder):
    model = model.module
    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        all_eval_result = evaluate_coco(val_dataset, model, decoder)

    return all_eval_result


def evaluate_coco(val_dataset, model, decoder):
    results, image_ids = [], []
    for index in range(len(val_dataset)):
        data = val_dataset[index]
        scale = data['scale']
        cls_heads, reg_heads, batch_anchors = model(data['img'].cuda().permute(
            2, 0, 1).float().unsqueeze(dim=0))
        scores, classes, boxes = decoder(cls_heads, reg_heads, batch_anchors)
        scores, classes, boxes = scores.cpu(), classes.cpu(), boxes.cpu()
        boxes /= scale

        # make sure decode batch_size=1
        # scores shape:[1,max_detection_num]
        # classes shape:[1,max_detection_num]
        # bboxes shape[1,max_detection_num,4]
        assert scores.shape[0] == 1

        scores = scores.squeeze(0)
        classes = classes.squeeze(0)
        boxes = boxes.squeeze(0)

        # for coco_eval,we need [x_min,y_min,w,h] format pred boxes
        boxes[:, 2:] -= boxes[:, :2]

        for object_score, object_class, object_box in zip(
                scores, classes, boxes):
            object_score = float(object_score)
            object_class = int(object_class)
            object_box = object_box.tolist()
            if object_class == -1:
                break

            image_result = {
                'image_id':
                val_dataset.image_ids[index],
                'category_id':
                val_dataset.find_category_id_from_coco_label(object_class),
                'score':
                object_score,
                'bbox':
                object_box,
            }
            results.append(image_result)

        image_ids.append(val_dataset.image_ids[index])

        print('{}/{}'.format(index, len(val_dataset)), end='\r')

    if not len(results):
        print("No target detected in test set images")
        return

    json.dump(results,
              open('{}_bbox_results.json'.format(val_dataset.set_name), 'w'),
              indent=4)

    # load results in COCO evaluation tool
    coco_true = val_dataset.coco
    coco_pred = coco_true.loadRes('{}_bbox_results.json'.format(
        val_dataset.set_name))

    coco_eval = COCOeval(coco_true, coco_pred, 'bbox')
    coco_eval.params.imgIds = image_ids
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
    all_eval_result = coco_eval.stats

    return all_eval_result

When training and testing on COCO, we follow the dataset setup in the RetinaNet paper: the model is trained on coco_2017_train and evaluated on coco_2017_val. We report the mAP at IoU=0.5:0.95, keeping at most 100 detections per image, over objects of all sizes (i.e. the stats[0] value from the _summarizeDets function of the COCOeval class in pycocotools.cocoeval) as the model's performance.

Testing RetinaNet on the VOC dataset

When training and testing on VOC, we follow the practice used by detectron2 for training and testing Faster R-CNN on VOC (https://github.com/facebookresearch/detectron2/blob/master/MODEL_ZOO.md): the model is trained on VOC2007 trainval + VOC2012 trainval and evaluated on VOC2007 test. At test time, mAP is computed with the VOC2007 11-point metric.

The evaluation code is the classic VOC evaluation code, with only the inputs and outputs adapted:

def compute_voc_ap(recall, precision, use_07_metric=True):
    if use_07_metric:
        # use voc 2007 11 point metric
        ap = 0.
        for t in np.arange(0., 1.1, 0.1):
            if np.sum(recall >= t) == 0:
                p = 0
            else:
                # get max precision  for recall >= t
                p = np.max(precision[recall >= t])
            # average 11 recall point precision
            ap = ap + p / 11.
    else:
        # use voc>=2010 metric,average all different recall precision as ap
        # recall add first value 0. and last value 1.
        mrecall = np.concatenate(([0.], recall, [1.]))
        # precision add first value 0. and last value 0.
        mprecision = np.concatenate(([0.], precision, [0.]))

        # compute the precision envelope
        for i in range(mprecision.size - 1, 0, -1):
            mprecision[i - 1] = np.maximum(mprecision[i - 1], mprecision[i])

        # to calculate area under PR curve, look for points where X axis (recall) changes value
        i = np.where(mrecall[1:] != mrecall[:-1])[0]

        # sum (\Delta recall) * prec
        ap = np.sum((mrecall[i + 1] - mrecall[i]) * mprecision[i + 1])

    return ap


def compute_ious(a, b):
    """
    :param a: [N,(x1,y1,x2,y2)]
    :param b: [M,(x1,y1,x2,y2)]
    :return:  IoU [N,M]
    """

    a = np.expand_dims(a, axis=1)  # [N,1,4]
    b = np.expand_dims(b, axis=0)  # [1,M,4]

    overlap = np.maximum(0.0,
                         np.minimum(a[..., 2:], b[..., 2:]) -
                         np.maximum(a[..., :2], b[..., :2]))  # [N,M,(w,h)]

    overlap = np.prod(overlap, axis=-1)  # [N,M]

    area_a = np.prod(a[..., 2:] - a[..., :2], axis=-1)
    area_b = np.prod(b[..., 2:] - b[..., :2], axis=-1)

    iou = overlap / (area_a + area_b - overlap)

    return iou


def validate(val_dataset, model, decoder):
    model = model.module
    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        all_ap, mAP = evaluate_voc(val_dataset,
                                   model,
                                   decoder,
                                   num_classes=20,
                                   iou_thread=0.5)

    return all_ap, mAP


def evaluate_voc(val_dataset, model, decoder, num_classes=20, iou_thread=0.5):
    preds, gts = [], []
    for index in tqdm(range(len(val_dataset))):
        data = val_dataset[index]
        img, gt_annot, scale = data['img'], data['annot'], data['scale']

        gt_bboxes, gt_classes = gt_annot[:, 0:4], gt_annot[:, 4]
        gt_bboxes /= scale

        gts.append([gt_bboxes, gt_classes])

        cls_heads, reg_heads, batch_anchors = model(img.cuda().permute(
            2, 0, 1).float().unsqueeze(dim=0))
        preds_scores, preds_classes, preds_boxes = decoder(
            cls_heads, reg_heads, batch_anchors)
        preds_scores, preds_classes, preds_boxes = preds_scores.cpu(
        ), preds_classes.cpu(), preds_boxes.cpu()
        preds_boxes /= scale

        # make sure decode batch_size=1
        # preds_scores shape:[1,max_detection_num]
        # preds_classes shape:[1,max_detection_num]
        # preds_bboxes shape[1,max_detection_num,4]
        assert preds_scores.shape[0] == 1

        preds_scores = preds_scores.squeeze(0)
        preds_classes = preds_classes.squeeze(0)
        preds_boxes = preds_boxes.squeeze(0)

        preds_scores = preds_scores[preds_classes > -1]
        preds_boxes = preds_boxes[preds_classes > -1]
        preds_classes = preds_classes[preds_classes > -1]

        preds.append([preds_boxes, preds_classes, preds_scores])

    print("all val sample decode done.")

    all_ap = {}
    for class_index in tqdm(range(num_classes)):
        per_class_gt_boxes = [
            image[0][image[1] == class_index] for image in gts
        ]
        per_class_pred_boxes = [
            image[0][image[1] == class_index] for image in preds
        ]
        per_class_pred_scores = [
            image[2][image[1] == class_index] for image in preds
        ]

        fp = np.zeros((0, ))
        tp = np.zeros((0, ))
        scores = np.zeros((0, ))
        total_gts = 0

        # loop for each sample
        for per_image_gt_boxes, per_image_pred_boxes, per_image_pred_scores in zip(
                per_class_gt_boxes, per_class_pred_boxes,
                per_class_pred_scores):
            total_gts = total_gts + len(per_image_gt_boxes)
            # one gt can only be assigned to one predicted bbox
            assigned_gt = []
            # loop for each predicted bbox
            for index in range(len(per_image_pred_boxes)):
                scores = np.append(scores, per_image_pred_scores[index])
                if per_image_gt_boxes.shape[0] == 0:
                    # if no gts found for the predicted bbox, assign the bbox to fp
                    fp = np.append(fp, 1)
                    tp = np.append(tp, 0)
                    continue
                pred_box = np.expand_dims(per_image_pred_boxes[index], axis=0)
                iou = compute_ious(per_image_gt_boxes, pred_box)
                gt_for_box = np.argmax(iou, axis=0)
                max_overlap = iou[gt_for_box, 0]
                if max_overlap >= iou_thread and gt_for_box not in assigned_gt:
                    fp = np.append(fp, 0)
                    tp = np.append(tp, 1)
                    assigned_gt.append(gt_for_box)
                else:
                    fp = np.append(fp, 1)
                    tp = np.append(tp, 0)
        # sort by score
        indices = np.argsort(-scores)
        fp = fp[indices]
        tp = tp[indices]
        # compute cumulative false positives and true positives
        fp = np.cumsum(fp)
        tp = np.cumsum(tp)
        # compute recall and precision
        recall = tp / total_gts
        precision = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
        ap = compute_voc_ap(recall, precision)
        all_ap[class_index] = ap

    mAP = 0.
    for _, class_mAP in all_ap.items():
        mAP += float(class_mAP)
    mAP /= num_classes

    return all_ap, mAP

Note that in compute_voc_ap, use_07_metric=True means mAP is computed with the VOC2007 11-point metric, while use_07_metric=False uses the newer metric introduced from VOC2010 onwards.
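
As a quick sanity check of the two metrics, here is a hypothetical toy example; the recall/precision arrays below are made up purely for illustration and assume compute_voc_ap has been defined as above:

import numpy as np

# made-up monotonically increasing recall values and their corresponding precision values
recall = np.array([0.1, 0.2, 0.4, 0.6, 0.8])
precision = np.array([1.0, 0.9, 0.8, 0.6, 0.5])

ap_07 = compute_voc_ap(recall, precision, use_07_metric=True)   # VOC2007 11-point interpolation
ap_10 = compute_voc_ap(recall, precision, use_07_metric=False)  # area under the interpolated PR curve
print(ap_07, ap_10)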

Complete training and evaluation code

We train for 12 epochs in total, evaluating the model every 5 epochs and once more when training finishes.
The complete training and evaluation code is shown below (this is the COCO version; training and testing on VOC only requires minor modifications).

config.py:

import os
import sys

BASE_DIR = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(BASE_DIR)

from public.path import COCO2017_path
from public.detection.dataset.cocodataset import CocoDetection, Resize, RandomFlip, RandomCrop, RandomTranslate

import torchvision.transforms as transforms
import torchvision.datasets as datasets


class Config(object):
    log = './log'  # Path to save log
    checkpoint_path = './checkpoints'  # Path to store checkpoint model
    resume = './checkpoints/latest.pth'  # load checkpoint model
    evaluate = None  # evaluate model path
    train_dataset_path = os.path.join(COCO2017_path, 'images/train2017')
    val_dataset_path = os.path.join(COCO2017_path, 'images/val2017')
    dataset_annotations_path = os.path.join(COCO2017_path, 'annotations')

    network = "resnet50_retinanet"
    pretrained = False
    num_classes = 80
    seed = 0
    input_image_size = 667

    train_dataset = CocoDetection(image_root_dir=train_dataset_path,
                                  annotation_root_dir=dataset_annotations_path,
                                  set="train2017",
                                  transform=transforms.Compose([
                                      RandomFlip(flip_prob=0.5),
                                      RandomCrop(crop_prob=0.5),
                                      RandomTranslate(translate_prob=0.5),
                                      Resize(resize=input_image_size),
                                  ]))
    val_dataset = CocoDetection(image_root_dir=val_dataset_path,
                                annotation_root_dir=dataset_annotations_path,
                                set="val2017",
                                transform=transforms.Compose([
                                    Resize(resize=input_image_size),
                                ]))

    epochs = 12
    batch_size = 16
    lr = 1e-5
    num_workers = 4
    print_interval = 100
    apex = True

train.py:

import sys
import os
import argparse
import random
import shutil
import time
import warnings
import json

BASE_DIR = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(BASE_DIR)
warnings.filterwarnings('ignore')

import numpy as np
from thop import profile
from thop import clever_format
from apex import amp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torchvision import transforms
from config import Config
from public.detection.dataset.cocodataset import COCODataPrefetcher, collater
from public.detection.models.loss import RetinaLoss
from public.detection.models.decode import RetinaDecoder
from public.detection.models.retinanet import resnet50_retinanet
from public.imagenet.utils import get_logger
from pycocotools.cocoeval import COCOeval


def parse_args():
    parser = argparse.ArgumentParser(
        description='PyTorch COCO Detection Training')
    parser.add_argument('--network',
                        type=str,
                        default=Config.network,
                        help='name of network')
    parser.add_argument('--lr',
                        type=float,
                        default=Config.lr,
                        help='learning rate')
    parser.add_argument('--epochs',
                        type=int,
                        default=Config.epochs,
                        help='num of training epochs')
    parser.add_argument('--batch_size',
                        type=int,
                        default=Config.batch_size,
                        help='batch size')
    parser.add_argument('--pretrained',
                        type=bool,
                        default=Config.pretrained,
                        help='load pretrained model params or not')
    parser.add_argument('--num_classes',
                        type=int,
                        default=Config.num_classes,
                        help='model classification num')
    parser.add_argument('--input_image_size',
                        type=int,
                        default=Config.input_image_size,
                        help='input image size')
    parser.add_argument('--num_workers',
                        type=int,
                        default=Config.num_workers,
                        help='number of worker to load data')
    parser.add_argument('--resume',
                        type=str,
                        default=Config.resume,
                        help='put the path to resuming file if needed')
    parser.add_argument('--checkpoints',
                        type=str,
                        default=Config.checkpoint_path,
                        help='path for saving trained models')
    parser.add_argument('--log',
                        type=str,
                        default=Config.log,
                        help='path to save log')
    parser.add_argument('--evaluate',
                        type=str,
                        default=Config.evaluate,
                        help='path for evaluate model')
    parser.add_argument('--seed', type=int, default=Config.seed, help='seed')
    parser.add_argument('--print_interval',
                        type=int,
                        default=Config.print_interval,
                        help='print interval')
    parser.add_argument('--apex',
                        type=bool,
                        default=Config.apex,
                        help='use apex or not')

    return parser.parse_args()


def validate(val_dataset, model, decoder):
    model = model.module
    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        all_eval_result = evaluate_coco(val_dataset, model, decoder)

    return all_eval_result


def evaluate_coco(val_dataset, model, decoder):
    results, image_ids = [], []
    for index in range(len(val_dataset)):
        data = val_dataset[index]
        scale = data['scale']
        cls_heads, reg_heads, batch_anchors = model(data['img'].cuda().permute(
            2, 0, 1).float().unsqueeze(dim=0))
        scores, classes, boxes = decoder(cls_heads, reg_heads, batch_anchors)
        scores, classes, boxes = scores.cpu(), classes.cpu(), boxes.cpu()
        boxes /= scale

        # make sure decode batch_size=1
        # scores shape:[1,max_detection_num]
        # classes shape:[1,max_detection_num]
        # bboxes shape[1,max_detection_num,4]
        assert scores.shape[0] == 1

        scores = scores.squeeze(0)
        classes = classes.squeeze(0)
        boxes = boxes.squeeze(0)

        # for coco_eval,we need [x_min,y_min,w,h] format pred boxes
        boxes[:, 2:] -= boxes[:, :2]

        for object_score, object_class, object_box in zip(
                scores, classes, boxes):
            object_score = float(object_score)
            object_class = int(object_class)
            object_box = object_box.tolist()
            if object_class == -1:
                break

            image_result = {
                'image_id':
                val_dataset.image_ids[index],
                'category_id':
                val_dataset.find_category_id_from_coco_label(object_class),
                'score':
                object_score,
                'bbox':
                object_box,
            }
            results.append(image_result)

        image_ids.append(val_dataset.image_ids[index])

        print('{}/{}'.format(index, len(val_dataset)), end='\r')

    if not len(results):
        print("No target detected in test set images")
        return

    json.dump(results,
              open('{}_bbox_results.json'.format(val_dataset.set_name), 'w'),
              indent=4)

    # load results in COCO evaluation tool
    coco_true = val_dataset.coco
    coco_pred = coco_true.loadRes('{}_bbox_results.json'.format(
        val_dataset.set_name))

    coco_eval = COCOeval(coco_true, coco_pred, 'bbox')
    coco_eval.params.imgIds = image_ids
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
    all_eval_result = coco_eval.stats

    return all_eval_result


def train(train_loader, model, criterion, optimizer, scheduler, epoch, logger,
          args):
    cls_losses, reg_losses, losses = [], [], []

    # switch to train mode
    model.train()

    iters = len(train_loader.dataset) // args.batch_size
    prefetcher = COCODataPrefetcher(train_loader)
    images, annotations = prefetcher.next()
    iter_index = 1

    while images is not None:
        images, annotations = images.cuda().float(), annotations.cuda()
        cls_heads, reg_heads, batch_anchors = model(images)
        cls_loss, reg_loss = criterion(cls_heads, reg_heads, batch_anchors,
                                       annotations)
        loss = cls_loss + reg_loss
        if cls_loss == 0.0 or reg_loss == 0.0:
            optimizer.zero_grad()
            # skip this batch, but still fetch the next one so the loop advances
            images, annotations = prefetcher.next()
            iter_index += 1
            continue

        if args.apex:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        optimizer.zero_grad()

        cls_losses.append(cls_loss.item())
        reg_losses.append(reg_loss.item())
        losses.append(loss.item())

        images, annotations = prefetcher.next()

        if iter_index % args.print_interval == 0:
            logger.info(
                f"train: epoch {epoch:0>3d}, iter [{iter_index:0>5d}, {iters:0>5d}], cls_loss: {cls_loss.item():.2f}, reg_loss: {reg_loss.item():.2f}, loss_total: {loss.item():.2f}"
            )

        iter_index += 1

    scheduler.step(np.mean(losses))

    return np.mean(cls_losses), np.mean(reg_losses), np.mean(losses)


def main(logger, args):
    if not torch.cuda.is_available():
        raise Exception("need gpu to train network!")

    torch.cuda.empty_cache()

    if args.seed is not None:
        random.seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True

    gpus = torch.cuda.device_count()
    logger.info(f'use {gpus} gpus')
    logger.info(f"args: {args}")

    cudnn.benchmark = True
    cudnn.enabled = True
    start_time = time.time()

    # dataset and dataloader
    logger.info('start loading data')
    train_loader = DataLoader(Config.train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              collate_fn=collater)
    logger.info('finish loading data')

    model = resnet50_retinanet(**{
        "pretrained": args.pretrained,
        "num_classes": args.num_classes,
    })

    for name, param in model.named_parameters():
        logger.info(f"{name},{param.requires_grad}")

    flops_input = torch.randn(1, 3, args.input_image_size,
                              args.input_image_size)
    flops, params = profile(model, inputs=(flops_input, ))
    flops, params = clever_format([flops, params], "%.3f")
    logger.info(f"model: '{args.network}', flops: {flops}, params: {params}")

    criterion = RetinaLoss(image_w=args.input_image_size,
                           image_h=args.input_image_size).cuda()
    decoder = RetinaDecoder(image_w=args.input_image_size,
                            image_h=args.input_image_size).cuda()

    model = model.cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=3,
                                                           verbose=True)

    if args.apex:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    model = nn.DataParallel(model)

    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            raise Exception(
                f"{args.evaluate} is not a file, please check it again")
        logger.info('start only evaluating')
        logger.info(f"start resuming model from {args.evaluate}")
        checkpoint = torch.load(args.evaluate,
                                map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model_state_dict'])
        all_eval_result = validate(Config.val_dataset, model, decoder)
        if all_eval_result is not None:
            logger.info(
                f"val: epoch: {checkpoint['epoch']:0>5d}, IoU=0.5:0.95,area=all,maxDets=100,mAP:{all_eval_result[0]:.3f}, IoU=0.5,area=all,maxDets=100,mAP:{all_eval_result[1]:.3f}, IoU=0.75,area=all,maxDets=100,mAP:{all_eval_result[2]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAP:{all_eval_result[3]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAP:{all_eval_result[4]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAP:{all_eval_result[5]:.3f}, IoU=0.5:0.95,area=all,maxDets=1,mAR:{all_eval_result[6]:.3f}, IoU=0.5:0.95,area=all,maxDets=10,mAR:{all_eval_result[7]:.3f}, IoU=0.5:0.95,area=all,maxDets=100,mAR:{all_eval_result[8]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAR:{all_eval_result[9]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAR:{all_eval_result[10]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAR:{all_eval_result[11]:.3f}"
            )

        return

    best_map = 0.0
    start_epoch = 1
    # resume training
    if os.path.exists(args.resume):
        logger.info(f"start resuming model from {args.resume}")
        checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
        start_epoch += checkpoint['epoch']
        best_map = checkpoint['best_map']
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        logger.info(
            f"finish resuming model from {args.resume}, epoch {checkpoint['epoch']}, best_map: {checkpoint['best_map']}, "
            f"loss: {checkpoint['loss']:3f}, cls_loss: {checkpoint['cls_loss']:2f}, reg_loss: {checkpoint['reg_loss']:2f}"
        )

    if not os.path.exists(args.checkpoints):
        os.makedirs(args.checkpoints)

    logger.info('start training')
    for epoch in range(start_epoch, args.epochs + 1):
        cls_losses, reg_losses, losses = train(train_loader, model, criterion,
                                               optimizer, scheduler, epoch,
                                               logger, args)
        logger.info(
            f"train: epoch {epoch:0>3d}, cls_loss: {cls_losses:.2f}, reg_loss: {reg_losses:.2f}, loss: {losses:.2f}"
        )

        if epoch % 5 == 0 or epoch == args.epochs:
            all_eval_result = validate(Config.val_dataset, model, decoder)
            logger.info(f"eval done.")
            if all_eval_result is not None:
                logger.info(
                    f"val: epoch: {epoch:0>5d}, IoU=0.5:0.95,area=all,maxDets=100,mAP:{all_eval_result[0]:.3f}, IoU=0.5,area=all,maxDets=100,mAP:{all_eval_result[1]:.3f}, IoU=0.75,area=all,maxDets=100,mAP:{all_eval_result[2]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAP:{all_eval_result[3]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAP:{all_eval_result[4]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAP:{all_eval_result[5]:.3f}, IoU=0.5:0.95,area=all,maxDets=1,mAR:{all_eval_result[6]:.3f}, IoU=0.5:0.95,area=all,maxDets=10,mAR:{all_eval_result[7]:.3f}, IoU=0.5:0.95,area=all,maxDets=100,mAR:{all_eval_result[8]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAR:{all_eval_result[9]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAR:{all_eval_result[10]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAR:{all_eval_result[11]:.3f}"
                )
                if all_eval_result[0] > best_map:
                    torch.save(model.module.state_dict(),
                               os.path.join(args.checkpoints, "best.pth"))
                    best_map = all_eval_result[0]
        torch.save(
            {
                'epoch': epoch,
                'best_map': best_map,
                'cls_loss': cls_losses,
                'reg_loss': reg_losses,
                'loss': losses,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
            }, os.path.join(args.checkpoints, 'latest.pth'))

    logger.info(f"finish training, best_map: {best_map:.3f}")
    training_time = (time.time() - start_time) / 3600
    logger.info(
        f"finish training, total training time: {training_time:.2f} hours")


if __name__ == '__main__':
    args = parse_args()
    logger = get_logger(__name__, args.log)
    main(logger, args)

The code above trains with nn.DataParallel; I will implement distributed training in the next article. To start training, simply run python train.py.

Assessing the reproduction

Given the reproduction of RetinaNet described across these six articles, three issues remain before my numbers can match the RetinaNet results reported in the paper:

  1. The ImageNet-pretrained ResNet50 weights used in Detectron and Detectron2 were trained by those projects themselves, and they are probably better than my ResNet50 pretrained weights (mine reach top1-error=23.488%). In my experience, the better the pretrained model, the better the finetuned result (the relationship is not linear, but it is positively correlated).
  2. The training above uses nn.DataParallel, which cannot use synchronized BN across GPUs, whereas Detectron and Detectron2 both use distributed training with synchronized BN. Without it, each BN layer updates its mean and standard deviation using only the batch on its own GPU; if the per-GPU batch size is smaller than the 16 used in the paper, the BN statistics are less accurate than with synchronized BN and model performance drops accordingly.
  3. Since I have not read all of the Detectron and Detectron2 code, those frameworks may contain additional training improvements that I have not yet discovered.

Problem 1 cannot be verified for now. Problem 2 will be addressed in the next article with distributed training and synchronized BN; a minimal sketch of that setup is shown below. As for problem 3, I do not currently have time to read through all of the Detectron and Detectron2 code, and corrections from readers are welcome.
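
For reference, here is a minimal single-node sketch of what enabling synchronized BN under DistributedDataParallel looks like. This anticipates the next article and is not part of the code above; it assumes the process-group environment variables are set by a launcher such as torch.distributed.launch:

import torch
import torch.nn as nn
import torch.distributed as dist

from public.detection.models.retinanet import resnet50_retinanet

# assumes MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE are set by the launcher;
# on a single node the global rank equals the local GPU id
dist.init_process_group(backend='nccl')
local_rank = dist.get_rank()
torch.cuda.set_device(local_rank)

model = resnet50_retinanet(pretrained=True, num_classes=80).cuda()
# convert every BatchNorm layer to SyncBatchNorm so that BN statistics are
# computed over the global batch instead of each GPU's local batch
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = nn.parallel.DistributedDataParallel(model,
                                            device_ids=[local_rank],
                                            output_device=local_rank)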

The model's performance on the COCO dataset is shown below (batch size 16, input resolution 600, roughly equivalent to resolution 450 in the RetinaNet paper). Training is still in progress; the table will be updated within the next couple of days.

Network                                   epoch5-mAP   epoch10-mAP   epoch12-mAP
ResNet50-RetinaNet-only-flip              (pending)    (pending)     (pending)
ResNet50-RetinaNet-flip-crop-translate    (pending)    (pending)     (pending)