[Dissecting the Ox] Implementing RetinaNet from Scratch (6): Training and Testing RetinaNet

All of the code has been uploaded to my GitHub repository: https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training
If you find it useful, please give it a star!
All of the code below was tested with PyTorch 1.4.

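In parts (1) through (5) of this series I fully reproduced RetinaNet. The idea behind the reproduction is to split the detector into three independent parts: the forward network, the loss computation, and the decoder. A close look shows that the loss and decode parts share some duplicated code. I deliberately did not factor it out for reuse, so that the three parts stay highly cohesive and loosely coupled: we can apply the latest detection methods to any one of them independently and then assemble the improved detector like building blocks. With that in place, we can start training and testing RetinaNet.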

Training RetinaNet

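In the RetinaNet paper (https://arxiv.org/pdf/1708.02002.pdf), the standard training recipe is: SGD with momentum=0.9 and weight_decay=0.0001, synchronized across 8 GPUs with a total batch_size of 16 (2 images per GPU). Training runs for 90000 iterations with an initial learning rate of 0.01, which is divided by 10 at iterations 60000 and 80000.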
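My training procedure differs slightly from the above, but not by much. 16 × 90000 / 118287 (the number of images in COCO2017 train) comes to roughly 12.17 epochs, so we train for 12 epochs. For simplicity I use an Adam-family optimizer (torch.optim.AdamW in train.py below), whose adaptive per-parameter step sizes remove the need for a hand-written step schedule; the learning rate itself is lowered by a ReduceLROnPlateau scheduler when the loss stops improving. In my experience Adam converges faster than SGD early in training but ends up slightly worse (the local optimum it reaches is not as good as SGD's). The gap is small, though: for RetinaNet the mAP difference is generally no more than 0.5 percentage points.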
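The epoch count quoted above is plain arithmetic; a minimal sketch that reproduces it (the variable names are chosen for this example only):

```python
# Convert the paper's iteration budget into epochs over COCO2017 train.
batch_size = 16
total_iterations = 90000
num_train_images = 118287  # number of images in COCO2017 train, as stated above

epochs = batch_size * total_iterations / num_train_images
print(f"{epochs:.2f}")  # 12.17
```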
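In the Detectron and Detectron2 frameworks, the standard RetinaNet training recipe described above is also called 1x_training. Similarly, multiplying the total number of iterations and the learning-rate decay iterations by 2 or by 3 gives 2x_training and 3x_training.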
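To make the 1x/2x/3x naming concrete, here is a small sketch of a step schedule and of how the iteration counts scale. The helper names step_lr and scaled_schedule are hypothetical and are not part of Detectron, Detectron2, or my repository; the numbers are the paper's 1x values quoted above.

```python
def step_lr(iteration, base_lr=0.01, milestones=(60000, 80000), gamma=0.1):
    """Learning rate at `iteration` when the rate is multiplied by `gamma`
    each time a milestone is reached."""
    passed = sum(1 for m in milestones if iteration >= m)
    return base_lr * gamma ** passed


def scaled_schedule(multiplier, total=90000, milestones=(60000, 80000)):
    """Total iterations and decay milestones for an Nx schedule."""
    return total * multiplier, tuple(m * multiplier for m in milestones)


print(scaled_schedule(1))  # (90000, (60000, 80000))
print(scaled_schedule(3))  # (270000, (180000, 240000))
```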

Testing RetinaNet on the COCO dataset

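To measure RetinaNet's performance on COCO we can directly use the API of the COCOeval class in pycocotools.cocoeval. We only need to feed the forward outputs of the RetinaNet class (including the anchors) into the RetinaDecoder class for decoding, and then divide the decoded bboxes by scale to map them back to the original image size (the decoded bbox coordinates are relative to the resized image). We then drop the invalid detections for each image (those whose class_index is -1), write the remaining ones to a json file in the required format, and call COCOeval to compute the metrics.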
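As a rough illustration of the post-processing just described, the sketch below rescales decoded boxes and converts them to the [x_min, y_min, w, h] layout used for COCO results. The function name to_coco_detections and the sample numbers are made up for this example. Unlike the real code, it does not map the class index to a COCO category id, and it skips padding entries with continue where the real code stops at the first -1.

```python
def to_coco_detections(image_id, scores, classes, boxes, scale):
    """`boxes` are [x_min, y_min, x_max, y_max] on the resized image."""
    detections = []
    for score, cls, (x1, y1, x2, y2) in zip(scores, classes, boxes):
        if cls == -1:  # padding entry, not a real detection
            continue
        # map the box back to the original image size
        x1, y1, x2, y2 = (v / scale for v in (x1, y1, x2, y2))
        detections.append({
            'image_id': image_id,
            'category_id': cls,  # the real code converts this to a COCO id
            'score': score,
            'bbox': [x1, y1, x2 - x1, y2 - y1],  # [x_min, y_min, w, h]
        })
    return detections


dets = to_coco_detections(image_id=7,
                          scores=[0.9, 0.0],
                          classes=[3, -1],
                          boxes=[[10, 20, 60, 120], [0, 0, 0, 0]],
                          scale=0.5)
print(dets[0]['bbox'])  # [20.0, 40.0, 100.0, 200.0]
```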

The COCOeval class reports 12 metrics:

self.maxDets = [1, 10, 100] # i.e. the max_detection_num mentioned in the decoder
stats[0] = _summarize(1)
stats[1] = _summarize(1, iouThr=.5, maxDets=self.params.maxDets[2])
stats[2] = _summarize(1, iouThr=.75, maxDets=self.params.maxDets[2])
stats[3] = _summarize(1, areaRng='small', maxDets=self.params.maxDets[2])
stats[4] = _summarize(1, areaRng='medium', maxDets=self.params.maxDets[2])
stats[5] = _summarize(1, areaRng='large', maxDets=self.params.maxDets[2])
stats[6] = _summarize(0, maxDets=self.params.maxDets[0])
stats[7] = _summarize(0, maxDets=self.params.maxDets[1])
stats[8] = _summarize(0, maxDets=self.params.maxDets[2])
stats[9] = _summarize(0, areaRng='small', maxDets=self.params.maxDets[2])
stats[10] = _summarize(0, areaRng='medium', maxDets=self.params.maxDets[2])
stats[11] = _summarize(0, areaRng='large', maxDets=self.params.maxDets[2])

The meaning of each result is as follows:

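 # Unless stated otherwise, the COCO performance quoted in object detection papers is stats[0]; whether it is measured on coco2017_val or coco2017_test depends on the paper, but the two usually differ by only 0.2~0.5 percentage points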
stats[0] : IoU=0.5:0.95,area=all,maxDets=100,mAP
stats[1] : IoU=0.5,area=all,maxDets=100,mAP
stats[2] : IoU=0.75,area=all,maxDets=100,mAP
stats[3] : IoU=0.5:0.95,area=small,maxDets=100,mAP
stats[4] : IoU=0.5:0.95,area=medium,maxDets=100,mAP
stats[5] : IoU=0.5:0.95,area=large,maxDets=100,mAP
stats[6] : IoU=0.5:0.95,area=all,maxDets=1,mAR
stats[7] : IoU=0.5:0.95,area=all,maxDets=10,mAR
stats[8] : IoU=0.5:0.95,area=all,maxDets=100,mAR
stats[9] : IoU=0.5:0.95,area=small,maxDets=100,mAR
stats[10] : IoU=0.5:0.95,area=medium,maxDets=100,mAR
stats[11] : IoU=0.5:0.95,area=large,maxDets=100,mAR
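For logging it can be convenient to attach these descriptions to the 12 numbers. The following is only a sketch: COCO_STAT_NAMES and label_stats are hypothetical helpers written for this article, not part of pycocotools, and the list simply restates the table above.

```python
COCO_STAT_NAMES = [
    'IoU=0.5:0.95,area=all,maxDets=100,mAP',
    'IoU=0.5,area=all,maxDets=100,mAP',
    'IoU=0.75,area=all,maxDets=100,mAP',
    'IoU=0.5:0.95,area=small,maxDets=100,mAP',
    'IoU=0.5:0.95,area=medium,maxDets=100,mAP',
    'IoU=0.5:0.95,area=large,maxDets=100,mAP',
    'IoU=0.5:0.95,area=all,maxDets=1,mAR',
    'IoU=0.5:0.95,area=all,maxDets=10,mAR',
    'IoU=0.5:0.95,area=all,maxDets=100,mAR',
    'IoU=0.5:0.95,area=small,maxDets=100,mAR',
    'IoU=0.5:0.95,area=medium,maxDets=100,mAR',
    'IoU=0.5:0.95,area=large,maxDets=100,mAR',
]


def label_stats(stats):
    """Pair each of the 12 values (in stats order) with its description."""
    if len(stats) != len(COCO_STAT_NAMES):
        raise ValueError('expected 12 values')
    return dict(zip(COCO_STAT_NAMES, stats))


# placeholder values 0..11 stand in for real results
labelled = label_stats(list(range(12)))
print(labelled['IoU=0.5:0.95,area=all,maxDets=100,mAP'])  # 0
```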

The test code for the COCO dataset is as follows:

def validate(val_dataset, model, decoder):
    model = model.module
    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        all_eval_result = evaluate_coco(val_dataset, model, decoder)

    return all_eval_result


def evaluate_coco(val_dataset, model, decoder):
    results, image_ids = [], []
    for index in range(len(val_dataset)):
        data = val_dataset[index]
        scale = data['scale']
        cls_heads, reg_heads, batch_anchors = model(data['img'].cuda().permute(
            2, 0, 1).float().unsqueeze(dim=0))
        scores, classes, boxes = decoder(cls_heads, reg_heads, batch_anchors)
        scores, classes, boxes = scores.cpu(), classes.cpu(), boxes.cpu()
        boxes /= scale

        # make sure decode batch_size=1
        # scores shape:[1,max_detection_num]
        # classes shape:[1,max_detection_num]
        # bboxes shape[1,max_detection_num,4]
        assert scores.shape[0] == 1

        scores = scores.squeeze(0)
        classes = classes.squeeze(0)
        boxes = boxes.squeeze(0)

        # for coco_eval,we need [x_min,y_min,w,h] format pred boxes
        boxes[:, 2:] -= boxes[:, :2]

        for object_score, object_class, object_box in zip(
                scores, classes, boxes):
            object_score = float(object_score)
            object_class = int(object_class)
            object_box = object_box.tolist()
            if object_class == -1:
                break

            image_result = {
                'image_id':
                val_dataset.image_ids[index],
                'category_id':
                val_dataset.find_category_id_from_coco_label(object_class),
                'score':
                object_score,
                'bbox':
                object_box,
            }
            results.append(image_result)

        image_ids.append(val_dataset.image_ids[index])

        print('{}/{}'.format(index, len(val_dataset)), end='\r')

    if not len(results):
        print("No target detected in test set images")
        return

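    # write the results inside a with block so the file is flushed and
    # closed before loadRes reads it back
    with open('{}_bbox_results.json'.format(val_dataset.set_name),
              'w') as f:
        json.dump(results, f, indent=4)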

    # load results in COCO evaluation tool
    coco_true = val_dataset.coco
    coco_pred = coco_true.loadRes('{}_bbox_results.json'.format(
        val_dataset.set_name))

    coco_eval = COCOeval(coco_true, coco_pred, 'bbox')
    coco_eval.params.imgIds = image_ids
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
    all_eval_result = coco_eval.stats

    return all_eval_result

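When training and testing on COCO we follow the dataset setup of the RetinaNet paper: the model is trained on coco_2017_train and tested on coco_2017_val. The reported metric is the mAP at IoU=0.5:0.95, with at most 100 detections kept per image and objects of all sizes included (i.e. the stats[0] value in the _summarizeDets function of the COCOeval class in pycocotools.cocoeval).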

Testing RetinaNet on the VOC dataset

When training and testing on VOC we follow the way detectron2 trains and tests Faster R-CNN on VOC (https://github.com/facebookresearch/detectron2/blob/master/MODEL_ZOO.md): the model is trained on VOC2007 trainval + VOC2012 trainval and tested on VOC2007 test. At test time mAP is computed with the VOC2007 11-point metric.

The test code is the classic VOC evaluation code, with only the inputs and outputs adapted:

def compute_voc_ap(recall, precision, use_07_metric=True):
    if use_07_metric:
        # use voc 2007 11 point metric
        ap = 0.
        for t in np.arange(0., 1.1, 0.1):
            if np.sum(recall >= t) == 0:
                p = 0
            else:
                # get max precision  for recall >= t
                p = np.max(precision[recall >= t])
            # average 11 recall point precision
            ap = ap + p / 11.
    else:
        # use voc>=2010 metric,average all different recall precision as ap
        # recall add first value 0. and last value 1.
        mrecall = np.concatenate(([0.], recall, [1.]))
        # precision add first value 0. and last value 0.
        mprecision = np.concatenate(([0.], precision, [0.]))

        # compute the precision envelope
        for i in range(mprecision.size - 1, 0, -1):
            mprecision[i - 1] = np.maximum(mprecision[i - 1], mprecision[i])

        # to calculate area under PR curve, look for points where X axis (recall) changes value
        i = np.where(mrecall[1:] != mrecall[:-1])[0]

        # sum (\Delta recall) * prec
        ap = np.sum((mrecall[i + 1] - mrecall[i]) * mprecision[i + 1])

    return ap


def compute_ious(a, b):
    """
    :param a: [N,(x1,y1,x2,y2)]
    :param b: [M,(x1,y1,x2,y2)]
    :return:  IoU [N,M]
    """

    a = np.expand_dims(a, axis=1)  # [N,1,4]
    b = np.expand_dims(b, axis=0)  # [1,M,4]

    overlap = np.maximum(0.0,
                         np.minimum(a[..., 2:], b[..., 2:]) -
                         np.maximum(a[..., :2], b[..., :2]))  # [N,M,(w,h)]

    overlap = np.prod(overlap, axis=-1)  # [N,M]

    area_a = np.prod(a[..., 2:] - a[..., :2], axis=-1)
    area_b = np.prod(b[..., 2:] - b[..., :2], axis=-1)

    iou = overlap / (area_a + area_b - overlap)

    return iou


def validate(val_dataset, model, decoder):
    model = model.module
    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        all_ap, mAP = evaluate_voc(val_dataset,
                                   model,
                                   decoder,
                                   num_classes=20,
                                   iou_thread=0.5)

    return all_ap, mAP


def evaluate_voc(val_dataset, model, decoder, num_classes=20, iou_thread=0.5):
    preds, gts = [], []
    for index in tqdm(range(len(val_dataset))):
        data = val_dataset[index]
        img, gt_annot, scale = data['img'], data['annot'], data['scale']

        gt_bboxes, gt_classes = gt_annot[:, 0:4], gt_annot[:, 4]
        gt_bboxes /= scale

        gts.append([gt_bboxes, gt_classes])

        cls_heads, reg_heads, batch_anchors = model(img.cuda().permute(
            2, 0, 1).float().unsqueeze(dim=0))
        preds_scores, preds_classes, preds_boxes = decoder(
            cls_heads, reg_heads, batch_anchors)
        preds_scores, preds_classes, preds_boxes = preds_scores.cpu(
        ), preds_classes.cpu(), preds_boxes.cpu()
        preds_boxes /= scale

        # make sure decode batch_size=1
        # preds_scores shape:[1,max_detection_num]
        # preds_classes shape:[1,max_detection_num]
        # preds_bboxes shape[1,max_detection_num,4]
        assert preds_scores.shape[0] == 1

        preds_scores = preds_scores.squeeze(0)
        preds_classes = preds_classes.squeeze(0)
        preds_boxes = preds_boxes.squeeze(0)

        preds_scores = preds_scores[preds_classes > -1]
        preds_boxes = preds_boxes[preds_classes > -1]
        preds_classes = preds_classes[preds_classes > -1]

        preds.append([preds_boxes, preds_classes, preds_scores])

    print("all val sample decode done.")

    all_ap = {}
    for class_index in tqdm(range(num_classes)):
        per_class_gt_boxes = [
            image[0][image[1] == class_index] for image in gts
        ]
        per_class_pred_boxes = [
            image[0][image[1] == class_index] for image in preds
        ]
        per_class_pred_scores = [
            image[2][image[1] == class_index] for image in preds
        ]

        fp = np.zeros((0, ))
        tp = np.zeros((0, ))
        scores = np.zeros((0, ))
        total_gts = 0

        # loop for each sample
        for per_image_gt_boxes, per_image_pred_boxes, per_image_pred_scores in zip(
                per_class_gt_boxes, per_class_pred_boxes,
                per_class_pred_scores):
            total_gts = total_gts + len(per_image_gt_boxes)
            # one gt can only be assigned to one predicted bbox
            assigned_gt = []
            # loop for each predicted bbox
            for index in range(len(per_image_pred_boxes)):
                scores = np.append(scores, per_image_pred_scores[index])
                if per_image_gt_boxes.shape[0] == 0:
                    # if no gts found for the predicted bbox, assign the bbox to fp
                    fp = np.append(fp, 1)
                    tp = np.append(tp, 0)
                    continue
                pred_box = np.expand_dims(per_image_pred_boxes[index], axis=0)
                iou = compute_ious(per_image_gt_boxes, pred_box)
                gt_for_box = np.argmax(iou, axis=0)
                max_overlap = iou[gt_for_box, 0]
                if max_overlap >= iou_thread and gt_for_box not in assigned_gt:
                    fp = np.append(fp, 0)
                    tp = np.append(tp, 1)
                    assigned_gt.append(gt_for_box)
                else:
                    fp = np.append(fp, 1)
                    tp = np.append(tp, 0)
        # sort by score
        indices = np.argsort(-scores)
        fp = fp[indices]
        tp = tp[indices]
        # compute cumulative false positives and true positives
        fp = np.cumsum(fp)
        tp = np.cumsum(tp)
        # compute recall and precision
        recall = tp / total_gts
        precision = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
        ap = compute_voc_ap(recall, precision)
        all_ap[class_index] = ap

    mAP = 0.
    for _, class_mAP in all_ap.items():
        mAP += float(class_mAP)
    mAP /= num_classes

    return all_ap, mAP

Note that in compute_voc_ap, use_07_metric=True computes AP with the VOC2007 11-point metric, while use_07_metric=False uses the newer computation introduced from VOC2010 onward.
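The difference between the two metrics is easiest to see on a toy example. The sketch below re-implements both branches of compute_voc_ap in plain Python (the function names ap_11_point and ap_all_point are chosen for this example) and applies them to four detections of one class, sorted by score as TP, TP, FP, TP, with 4 ground-truth boxes in total. The data is invented purely for illustration.

```python
def ap_11_point(recall, precision):
    """VOC2007: average the best precision at recall >= 0.0, 0.1, ..., 1.0."""
    ap = 0.0
    for i in range(11):
        t = i / 10
        candidates = [p for r, p in zip(recall, precision) if r >= t]
        ap += (max(candidates) if candidates else 0.0) / 11
    return ap


def ap_all_point(recall, precision):
    """VOC2010+: area under the monotonically decreasing PR envelope."""
    mrec = [0.0] + list(recall) + [1.0]
    mpre = [0.0] + list(precision) + [0.0]
    # make precision non-increasing from left to right
    for i in range(len(mpre) - 1, 0, -1):
        mpre[i - 1] = max(mpre[i - 1], mpre[i])
    # sum (delta recall) * precision wherever recall changes
    return sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]
               for i in range(len(mrec) - 1) if mrec[i + 1] != mrec[i])


# cumulative TP = 1, 2, 2, 3 and cumulative FP = 0, 0, 1, 1 with 4 gts
recall = [0.25, 0.5, 0.5, 0.75]
precision = [1.0, 1.0, 2 / 3, 0.75]

print(f"{ap_11_point(recall, precision):.4f}")   # 0.6818
print(f"{ap_all_point(recall, precision):.4f}")  # 0.6875
```

The two values are close but not identical, which is why results computed with different VOC metrics should not be compared directly.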

Complete training and testing code

We train for 12 epochs in total, evaluating the model every 5 epochs and once more when training finishes.
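With these settings, the evaluation condition used in train.py (epoch % 5 == 0 or epoch == epochs) fires at the following epochs; this is a quick check rather than part of the training code:

```python
epochs = 12
eval_epochs = [e for e in range(1, epochs + 1) if e % 5 == 0 or e == epochs]
print(eval_epochs)  # [5, 10, 12]
```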
The complete training and testing code is as follows (this is the code for COCO; training and testing on VOC requires only minor changes).

config.py:

import os
import sys

BASE_DIR = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(BASE_DIR)

from public.path import COCO2017_path
from public.detection.dataset.cocodataset import CocoDetection, Resize, RandomFlip, RandomCrop, RandomTranslate

import torchvision.transforms as transforms
import torchvision.datasets as datasets


class Config(object):
    log = './log'  # Path to save log
    checkpoint_path = './checkpoints'  # Path to store checkpoint model
    resume = './checkpoints/latest.pth'  # load checkpoint model
    evaluate = None  # evaluate model path
    train_dataset_path = os.path.join(COCO2017_path, 'images/train2017')
    val_dataset_path = os.path.join(COCO2017_path, 'images/val2017')
    dataset_annotations_path = os.path.join(COCO2017_path, 'annotations')

    network = "resnet50_retinanet"
    pretrained = False
    num_classes = 80
    seed = 0
    input_image_size = 667

    train_dataset = CocoDetection(image_root_dir=train_dataset_path,
                                  annotation_root_dir=dataset_annotations_path,
                                  set="train2017",
                                  transform=transforms.Compose([
                                      RandomFlip(flip_prob=0.5),
                                      RandomCrop(crop_prob=0.5),
                                      RandomTranslate(translate_prob=0.5),
                                      Resize(resize=input_image_size),
                                  ]))
    val_dataset = CocoDetection(image_root_dir=val_dataset_path,
                                annotation_root_dir=dataset_annotations_path,
                                set="val2017",
                                transform=transforms.Compose([
                                    Resize(resize=input_image_size),
                                ]))

    epochs = 12
    batch_size = 16
    lr = 1e-5
    num_workers = 4
    print_interval = 100
    apex = True

train.py:

import sys
import os
import argparse
import random
import shutil
import time
import warnings
import json

BASE_DIR = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(BASE_DIR)
warnings.filterwarnings('ignore')

import numpy as np
from thop import profile
from thop import clever_format
from apex import amp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torchvision import transforms
from config import Config
from public.detection.dataset.cocodataset import COCODataPrefetcher, collater
from public.detection.models.loss import RetinaLoss
from public.detection.models.decode import RetinaDecoder
from public.detection.models.retinanet import resnet50_retinanet
from public.imagenet.utils import get_logger
from pycocotools.cocoeval import COCOeval


def parse_args():
    parser = argparse.ArgumentParser(
        description='PyTorch COCO Detection Training')
    parser.add_argument('--network',
                        type=str,
                        default=Config.network,
                        help='name of network')
    parser.add_argument('--lr',
                        type=float,
                        default=Config.lr,
                        help='learning rate')
    parser.add_argument('--epochs',
                        type=int,
                        default=Config.epochs,
                        help='num of training epochs')
    parser.add_argument('--batch_size',
                        type=int,
                        default=Config.batch_size,
                        help='batch size')
    parser.add_argument('--pretrained',
                        type=bool,
                        default=Config.pretrained,
                        help='load pretrained model params or not')
    parser.add_argument('--num_classes',
                        type=int,
                        default=Config.num_classes,
                        help='model classification num')
    parser.add_argument('--input_image_size',
                        type=int,
                        default=Config.input_image_size,
                        help='input image size')
    parser.add_argument('--num_workers',
                        type=int,
                        default=Config.num_workers,
                        help='number of worker to load data')
    parser.add_argument('--resume',
                        type=str,
                        default=Config.resume,
                        help='put the path to resuming file if needed')
    parser.add_argument('--checkpoints',
                        type=str,
                        default=Config.checkpoint_path,
                        help='path for saving trained models')
    parser.add_argument('--log',
                        type=str,
                        default=Config.log,
                        help='path to save log')
    parser.add_argument('--evaluate',
                        type=str,
                        default=Config.evaluate,
                        help='path for evaluate model')
    parser.add_argument('--seed', type=int, default=Config.seed, help='seed')
    parser.add_argument('--print_interval',
                        type=int,
                        default=Config.print_interval,
                        help='print interval')
    parser.add_argument('--apex',
                        type=bool,
                        default=Config.apex,
                        help='use apex or not')

    return parser.parse_args()


def validate(val_dataset, model, decoder):
    model = model.module
    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        all_eval_result = evaluate_coco(val_dataset, model, decoder)

    return all_eval_result


def evaluate_coco(val_dataset, model, decoder):
    results, image_ids = [], []
    for index in range(len(val_dataset)):
        data = val_dataset[index]
        scale = data['scale']
        cls_heads, reg_heads, batch_anchors = model(data['img'].cuda().permute(
            2, 0, 1).float().unsqueeze(dim=0))
        scores, classes, boxes = decoder(cls_heads, reg_heads, batch_anchors)
        scores, classes, boxes = scores.cpu(), classes.cpu(), boxes.cpu()
        boxes /= scale

        # make sure decode batch_size=1
        # scores shape:[1,max_detection_num]
        # classes shape:[1,max_detection_num]
        # bboxes shape[1,max_detection_num,4]
        assert scores.shape[0] == 1

        scores = scores.squeeze(0)
        classes = classes.squeeze(0)
        boxes = boxes.squeeze(0)

        # for coco_eval,we need [x_min,y_min,w,h] format pred boxes
        boxes[:, 2:] -= boxes[:, :2]

        for object_score, object_class, object_box in zip(
                scores, classes, boxes):
            object_score = float(object_score)
            object_class = int(object_class)
            object_box = object_box.tolist()
            if object_class == -1:
                break

            image_result = {
                'image_id':
                val_dataset.image_ids[index],
                'category_id':
                val_dataset.find_category_id_from_coco_label(object_class),
                'score':
                object_score,
                'bbox':
                object_box,
            }
            results.append(image_result)

        image_ids.append(val_dataset.image_ids[index])

        print('{}/{}'.format(index, len(val_dataset)), end='\r')

    if not len(results):
        print("No target detected in test set images")
        return

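    # write the results inside a with block so the file is flushed and
    # closed before loadRes reads it back
    with open('{}_bbox_results.json'.format(val_dataset.set_name),
              'w') as f:
        json.dump(results, f, indent=4)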

    # load results in COCO evaluation tool
    coco_true = val_dataset.coco
    coco_pred = coco_true.loadRes('{}_bbox_results.json'.format(
        val_dataset.set_name))

    coco_eval = COCOeval(coco_true, coco_pred, 'bbox')
    coco_eval.params.imgIds = image_ids
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
    all_eval_result = coco_eval.stats

    return all_eval_result


def train(train_loader, model, criterion, optimizer, scheduler, epoch, logger,
          args):
    cls_losses, reg_losses, losses = [], [], []

    # switch to train mode
    model.train()

    iters = len(train_loader.dataset) // args.batch_size
    prefetcher = COCODataPrefetcher(train_loader)
    images, annotations = prefetcher.next()
    iter_index = 1

    while images is not None:
        images, annotations = images.cuda().float(), annotations.cuda()
        cls_heads, reg_heads, batch_anchors = model(images)
        cls_loss, reg_loss = criterion(cls_heads, reg_heads, batch_anchors,
                                       annotations)
        loss = cls_loss + reg_loss
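        if cls_loss == 0.0 or reg_loss == 0.0:
            # skip this batch, but fetch the next one first; otherwise the
            # loop would process the same batch forever
            optimizer.zero_grad()
            images, annotations = prefetcher.next()
            iter_index += 1
            continue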

        if args.apex:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        optimizer.zero_grad()

        cls_losses.append(cls_loss.item())
        reg_losses.append(reg_loss.item())
        losses.append(loss.item())

        images, annotations = prefetcher.next()

        if iter_index % args.print_interval == 0:
            logger.info(
                f"train: epoch {epoch:0>3d}, iter [{iter_index:0>5d}, {iters:0>5d}], cls_loss: {cls_loss.item():.2f}, reg_loss: {reg_loss.item():.2f}, loss_total: {loss.item():.2f}"
            )

        iter_index += 1

    scheduler.step(np.mean(losses))

    return np.mean(cls_losses), np.mean(reg_losses), np.mean(losses)


def main(logger, args):
    if not torch.cuda.is_available():
        raise Exception("need gpu to train network!")

    torch.cuda.empty_cache()

    if args.seed is not None:
        random.seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True

    gpus = torch.cuda.device_count()
    logger.info(f'use {gpus} gpus')
    logger.info(f"args: {args}")

    cudnn.benchmark = True
    cudnn.enabled = True
    start_time = time.time()

    # dataset and dataloader
    logger.info('start loading data')
    train_loader = DataLoader(Config.train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              collate_fn=collater)
    logger.info('finish loading data')

    model = resnet50_retinanet(**{
        "pretrained": args.pretrained,
        "num_classes": args.num_classes,
    })

    for name, param in model.named_parameters():
        logger.info(f"{name},{param.requires_grad}")

    flops_input = torch.randn(1, 3, args.input_image_size,
                              args.input_image_size)
    flops, params = profile(model, inputs=(flops_input, ))
    flops, params = clever_format([flops, params], "%.3f")
    logger.info(f"model: '{args.network}', flops: {flops}, params: {params}")

    criterion = RetinaLoss(image_w=args.input_image_size,
                           image_h=args.input_image_size).cuda()
    decoder = RetinaDecoder(image_w=args.input_image_size,
                            image_h=args.input_image_size).cuda()

    model = model.cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=3,
                                                           verbose=True)

    if args.apex:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    model = nn.DataParallel(model)

    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            raise Exception(
                f"{args.evaluate} is not a file, please check it again")
        logger.info('start only evaluating')
        logger.info(f"start resuming model from {args.evaluate}")
        checkpoint = torch.load(args.evaluate,
                                map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model_state_dict'])
        all_eval_result = validate(Config.val_dataset, model, decoder)
        if all_eval_result is not None:
            logger.info(
                f"val: epoch: {checkpoint['epoch']:0>5d}, IoU=0.5:0.95,area=all,maxDets=100,mAP:{all_eval_result[0]:.3f}, IoU=0.5,area=all,maxDets=100,mAP:{all_eval_result[1]:.3f}, IoU=0.75,area=all,maxDets=100,mAP:{all_eval_result[2]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAP:{all_eval_result[3]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAP:{all_eval_result[4]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAP:{all_eval_result[5]:.3f}, IoU=0.5:0.95,area=all,maxDets=1,mAR:{all_eval_result[6]:.3f}, IoU=0.5:0.95,area=all,maxDets=10,mAR:{all_eval_result[7]:.3f}, IoU=0.5:0.95,area=all,maxDets=100,mAR:{all_eval_result[8]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAR:{all_eval_result[9]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAR:{all_eval_result[10]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAR:{all_eval_result[11]:.3f}"
            )

        return

    best_map = 0.0
    start_epoch = 1
    # resume training
    if os.path.exists(args.resume):
        logger.info(f"start resuming model from {args.resume}")
        checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
        start_epoch += checkpoint['epoch']
        best_map = checkpoint['best_map']
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        logger.info(
            f"finish resuming model from {args.resume}, epoch {checkpoint['epoch']}, best_map: {checkpoint['best_map']}, "
            f"loss: {checkpoint['loss']:.3f}, cls_loss: {checkpoint['cls_loss']:.2f}, reg_loss: {checkpoint['reg_loss']:.2f}"
        )

    if not os.path.exists(args.checkpoints):
        os.makedirs(args.checkpoints)

    logger.info('start training')
    for epoch in range(start_epoch, args.epochs + 1):
        cls_losses, reg_losses, losses = train(train_loader, model, criterion,
                                               optimizer, scheduler, epoch,
                                               logger, args)
        logger.info(
            f"train: epoch {epoch:0>3d}, cls_loss: {cls_losses:.2f}, reg_loss: {reg_losses:.2f}, loss: {losses:.2f}"
        )

        if epoch % 5 == 0 or epoch == args.epochs:
            # validate takes (val_dataset, model, decoder)
            all_eval_result = validate(Config.val_dataset, model, decoder)
            logger.info(f"eval done.")
            if all_eval_result is not None:
                logger.info(
                    f"val: epoch: {epoch:0>5d}, IoU=0.5:0.95,area=all,maxDets=100,mAP:{all_eval_result[0]:.3f}, IoU=0.5,area=all,maxDets=100,mAP:{all_eval_result[1]:.3f}, IoU=0.75,area=all,maxDets=100,mAP:{all_eval_result[2]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAP:{all_eval_result[3]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAP:{all_eval_result[4]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAP:{all_eval_result[5]:.3f}, IoU=0.5:0.95,area=all,maxDets=1,mAR:{all_eval_result[6]:.3f}, IoU=0.5:0.95,area=all,maxDets=10,mAR:{all_eval_result[7]:.3f}, IoU=0.5:0.95,area=all,maxDets=100,mAR:{all_eval_result[8]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAR:{all_eval_result[9]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAR:{all_eval_result[10]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAR:{all_eval_result[11]:.3f}"
                )
                if all_eval_result[0] > best_map:
                    torch.save(model.module.state_dict(),
                               os.path.join(args.checkpoints, "best.pth"))
                    best_map = all_eval_result[0]
        torch.save(
            {
                'epoch': epoch,
                'best_map': best_map,
                'cls_loss': cls_losses,
                'reg_loss': reg_losses,
                'loss': losses,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
            }, os.path.join(args.checkpoints, 'latest.pth'))

    logger.info(f"finish training, best_map: {best_map:.3f}")
    training_time = (time.time() - start_time) / 3600
    logger.info(
        f"finish training, total training time: {training_time:.2f} hours")


if __name__ == '__main__':
    args = parse_args()
    logger = get_logger(__name__, args.log)
    main(logger, args)

The code above trains in nn.DataParallel mode; I will implement distributed training in the next article. To start training, simply run python train.py.
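One practical consequence of nn.DataParallel is worth noting: train.py saves latest.pth from the wrapped model, so its parameter names carry a "module." prefix, while best.pth is saved from model.module and has no prefix. If you want to load the weights stored in latest.pth into an unwrapped model, the prefix has to be removed first. A minimal sketch follows; the helper strip_module_prefix is hypothetical and not part of my repository, and a plain dict stands in for a real state dict.

```python
def strip_module_prefix(state_dict):
    """Return a copy of `state_dict` whose keys have no leading 'module.'."""
    prefix = 'module.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# stand-in for checkpoint['model_state_dict'] loaded from latest.pth
wrapped = {'module.conv1.weight': 1, 'module.fc.bias': 2}
print(sorted(strip_module_prefix(wrapped)))  # ['conv1.weight', 'fc.bias']
```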

Evaluating the reproduction

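Given how these six articles reproduce each aspect of RetinaNet, three issues still stand between this implementation and the numbers reported for RetinaNet in the paper: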

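  1. The ImageNet-pretrained ResNet50 weights used in Detectron and Detectron2 were trained by their authors and may be better than my pretrained ResNet50 weights (my pretrained ResNet50 reaches top1-error=23.488%). In my experience, the better the pretrained model, the better the result after finetuning (the relationship is not linear, but it is positively correlated).
  2. The training above uses nn.DataParallel, which cannot use cross-GPU synchronized BN, whereas Detectron and Detectron2 both use distributed training with cross-GPU synchronized BN. Without synchronized BN, each BN layer can only update its mean and standard deviation from the part of the batch that sits on its own GPU. If the per-GPU batch size is smaller than the 16 used in the paper, the BN statistics are estimated less accurately than with synchronized BN, and model performance drops somewhat as a result.
  3. Since I have not read all of the code in Detectron and Detectron2, these two frameworks may contain further training improvements that I have not discovered.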
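To illustrate issue 2, the sketch below compares the mean and variance that a BN layer would compute for one channel over a whole batch with the statistics computed separately on each of two GPUs. The numbers are invented and the split is deliberately extreme so that the effect is visible; with shuffled real data the gap is smaller, but it grows as the per-GPU batch shrinks.

```python
from statistics import mean, pvariance

# activations of one channel for a batch of 8 samples (made-up values)
batch = [0.0, 1.0, 2.0, 3.0, 10.0, 11.0, 12.0, 13.0]
per_gpu = [batch[:4], batch[4:]]  # 2 GPUs with 4 samples each

# statistics a synchronized BN layer would see (whole batch)
global_stats = (mean(batch), pvariance(batch))
# statistics each GPU sees without synchronization
local_stats = [(mean(chunk), pvariance(chunk)) for chunk in per_gpu]

print(global_stats)  # (6.5, 26.25)
print(local_stats)   # [(1.5, 1.25), (11.5, 1.25)]
```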

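Issue 1 cannot be verified for now. Issue 2 will be addressed in the next chapter by training with distributed training plus cross-GPU synchronized BN. As for issue 3, I do not currently have the energy to read through all of the Detectron and Detectron2 code; corrections from readers are welcome.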

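The model's performance on COCO is shown below (batch size 16, input resolution 600, roughly equivalent to resolution 450 in the RetinaNet paper): (training is still in progress; the results will be added in the next couple of days)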

Network epoch5-mAP epoch10-mAP epoch12-mAP
ResNet50-RetinaNet-only-flip
ResNet50-RetinaNet-flip-crop-translate