base model第六彈:在ImageNet上訓練RegNet

RegNet網絡結構

根據論文和官方開源代碼,我用Pytorch重新構建了RegNet網絡(論文:https://arxiv.org/pdf/2003.13678.pdf 。開源代碼:https://github.com/facebookresearch/pycls)。
網絡結構如下:

"""
Designing Network Design Spaces
https://arxiv.org/pdf/2003.13678.pdf
"""
import numpy as np

import torch
import torch.nn as nn

# Local directory holding the pretrained weight files referenced by model_urls.
root_path = '/home/zhugechaoran/Downloads/pytorch_imagenet_model/public/pretrained_models'

# Public API: one factory function per RegNetY variant defined below.
__all__ = [
    'RegNetY_200MF',
    'RegNetY_400MF',
    'RegNetY_600MF',
    'RegNetY_800MF',
    'RegNetY_1_6GF',
    'RegNetY_3_2GF',
    'RegNetY_4_0GF',
    'RegNetY_6_4GF',
    'RegNetY_8_0GF',
    'RegNetY_12GF',
    'RegNetY_16GF',
    'RegNetY_32GF',
]

# Despite the name, these are local filesystem paths (under root_path), not
# URLs; they are passed straight to torch.load by _regnet. Entries set to
# 'empty' have no trained weights available.
model_urls = {
    'RegNetY_200MF':
    '{}/regnet/RegNetY_200MF-epoch100-acc70.096.pth'.format(root_path),
    'RegNetY_400MF':
    '{}/regnet/RegNetY_400MF-epoch100-acc73.79.pth'.format(root_path),
    'RegNetY_600MF':
    '{}/regnet/RegNetY_600MF-epoch100-acc74.724.pth'.format(root_path),
    'RegNetY_800MF':
    '{}/regnet/RegNetY_800MF-epoch100-acc75.994.pth'.format(root_path),
    'RegNetY_1_6GF':
    '{}/regnet/RegNetY_1_6GF-epoch100-acc77.308.pth'.format(root_path),
    'RegNetY_3_2GF':
    '{}/regnet/RegNetY_3_2GF-epoch100-acc78.908.pth'.format(root_path),
    'RegNetY_4_0GF':
    '{}/regnet/RegNetY_4_0GF-epoch100-acc78.136.pth'.format(root_path),
    'RegNetY_6_4GF':
    '{}/regnet/RegNetY_6_4GF-epoch100-acc78.77.pth'.format(root_path),
    'RegNetY_8_0GF':
    'empty',
    'RegNetY_12GF':
    'empty',
    'RegNetY_16GF':
    'empty',
    'RegNetY_32GF':
    'empty',
}

# Per-variant RegNet design-space parameters, consumed by get_regnet_config:
#   stem_width: output channels of the stem convolution
#   w_a: slope of the linear per-block width ramp (u_i = w_0 + w_a * i)
#   w_0: initial block width
#   w_m: base used to quantize block widths to powers of w_m
#   groups: value passed as the `groups` argument of each block's 3x3 conv
#           (capped per stage at the stage width in get_regnet_config)
#   depth: total number of blocks summed over all stages
#   has_se: whether blocks include a squeeze-and-excitation module
RegnetY_types_config = {
    'RegNetY_200MF': {
        "stem_width": 32,
        "w_a": 36.44,
        "w_0": 24,
        "w_m": 2.49,
        "groups": 8,
        "depth": 13,
        "has_se": True,
    },
    'RegNetY_400MF': {
        "stem_width": 32,
        "w_a": 27.89,
        "w_0": 48,
        "w_m": 2.09,
        "groups": 8,
        "depth": 16,
        "has_se": True,
    },
    'RegNetY_600MF': {
        "stem_width": 32,
        "w_a": 32.54,
        "w_0": 48,
        "w_m": 2.32,
        "groups": 16,
        "depth": 15,
        "has_se": True,
    },
    'RegNetY_800MF': {
        "stem_width": 32,
        "w_a": 38.84,
        "w_0": 56,
        "w_m": 2.4,
        "groups": 16,
        "depth": 14,
        "has_se": True,
    },
    'RegNetY_1_6GF': {
        "stem_width": 32,
        "w_a": 20.71,
        "w_0": 48,
        "w_m": 2.65,
        "groups": 24,
        "depth": 27,
        "has_se": True,
    },
    'RegNetY_3_2GF': {
        "stem_width": 32,
        "w_a": 42.63,
        "w_0": 80,
        "w_m": 2.66,
        "groups": 24,
        "depth": 21,
        "has_se": True,
    },
    'RegNetY_4_0GF': {
        "stem_width": 32,
        "w_a": 31.41,
        "w_0": 96,
        "w_m": 2.24,
        "groups": 64,
        "depth": 22,
        "has_se": True,
    },
    'RegNetY_6_4GF': {
        "stem_width": 32,
        "w_a": 33.22,
        "w_0": 112,
        "w_m": 2.27,
        "groups": 72,
        "depth": 25,
        "has_se": True,
    },
    'RegNetY_8_0GF': {
        "stem_width": 32,
        "w_a": 76.82,
        "w_0": 192,
        "w_m": 2.19,
        "groups": 56,
        "depth": 17,
        "has_se": True,
    },
    'RegNetY_12GF': {
        "stem_width": 32,
        "w_a": 73.36,
        "w_0": 168,
        "w_m": 2.37,
        "groups": 112,
        "depth": 19,
        "has_se": True,
    },
    'RegNetY_16GF': {
        "stem_width": 32,
        "w_a": 106.23,
        "w_0": 200,
        "w_m": 2.48,
        "groups": 112,
        "depth": 18,
        "has_se": True,
    },
    'RegNetY_32GF': {
        "stem_width": 32,
        "w_a": 115.89,
        "w_0": 232,
        "w_m": 2.53,
        "groups": 232,
        "depth": 20,
        "has_se": True,
    },
}


def get_regnet_config(regnet_type, q=8):
    """Derive per-stage widths, depths and groups from design-space params.

    Args:
        regnet_type: dict with keys "stem_width", "has_se", "w_a", "w_0",
            "w_m", "depth", "groups" (see RegnetY_types_config).
        q: quantization step for block widths (w_0 must be divisible by q).

    Returns:
        (stem_width, has_se, per_stage_width, per_stage_depth,
         per_stage_groups) where the three lists have one entry per stage.
    """
    stem_width, has_se = regnet_type["stem_width"], regnet_type["has_se"]

    w_a, w_0, w_m = regnet_type["w_a"], regnet_type["w_0"], regnet_type["w_m"]
    depth, groups = regnet_type["depth"], regnet_type["groups"]

    assert w_a >= 0 and w_0 > 0 and w_m > 1 and w_0 % q == 0

    # Continuous widths u_i = w_0 + w_a * i, snapped to the nearest power of
    # w_m and then quantized to a multiple of q.
    ks = np.round(np.log((np.arange(depth) * w_a + w_0) / w_0) / np.log(w_m))
    per_block_width = w_0 * np.power(w_m, ks)
    per_block_width = (np.round(np.divide(per_block_width, q)) *
                       q).astype(int).tolist()

    # A new stage starts wherever the block width changes; stage depths are
    # the gaps between consecutive transition indices. (The original code
    # zipped the same width list twice, making `r != rp` redundant.)
    transitions = [
        w != w_prev
        for w, w_prev in zip(per_block_width + [0], [0] + per_block_width)
    ]
    per_stage_depth = np.diff([i for i, t in enumerate(transitions)
                               if t]).tolist()

    # Block widths are non-decreasing, so the sorted unique values are the
    # per-stage widths in order.
    per_stage_width = np.unique(per_block_width).tolist()

    # Cap groups at the stage width, then round the width to a multiple of
    # its group count so the grouped conv divides evenly.
    per_stage_groups = [min(groups, per_w) for per_w in per_stage_width]
    per_stage_width = [
        int(round(per_w / per_g) * per_g)
        for per_w, per_g in zip(per_stage_width, per_stage_groups)
    ]

    return stem_width, has_se, per_stage_width, per_stage_depth, per_stage_groups


class ConvBnActBlock(nn.Module):
    """Conv2d followed by optional BatchNorm and optional ReLU."""

    def __init__(self,
                 inplanes,
                 planes,
                 kernel_size,
                 stride,
                 padding,
                 groups=1,
                 has_bn=True,
                 has_act=True):
        super().__init__()
        self.has_bn = has_bn
        self.has_act = has_act
        # bias is omitted: BatchNorm (when enabled) supplies the shift.
        self.conv = nn.Conv2d(inplanes,
                              planes,
                              kernel_size,
                              stride=stride,
                              padding=padding,
                              groups=groups,
                              bias=False)
        if has_bn:
            self.bn = nn.BatchNorm2d(planes)
        if has_act:
            self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv(x)
        if self.has_bn:
            out = self.bn(out)
        if self.has_act:
            out = self.act(out)
        return out


class SeBlock(nn.Module):
    """Squeeze-and-excitation: per-channel gating from global pooled stats."""

    def __init__(self, inplanes, reduction=4):
        super(SeBlock, self).__init__()
        # Bottleneck width of the excitation MLP (never below 1 channel).
        squeezed_planes = max(1, int(inplanes // reduction))
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.conv1 = nn.Conv2d(inplanes,
                               squeezed_planes,
                               kernel_size=1,
                               stride=1,
                               padding=0,
                               bias=True)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(squeezed_planes,
                               inplanes,
                               kernel_size=1,
                               stride=1,
                               padding=0,
                               bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        # Squeeze to 1x1, excite through the two 1x1 convs, gate the input.
        gate = self.avgpool(inputs)
        gate = self.relu(self.conv1(gate))
        gate = self.sigmoid(self.conv2(gate))
        return inputs * gate


class XBlock(nn.Module):
    """RegNet residual bottleneck block.

    Main path: 1x1 conv -> grouped 3x3 conv (carries the stride) -> 1x1
    conv, with optional squeeze-and-excitation applied to the block output
    (after conv3 in this implementation). The shortcut gets a 1x1
    projection when `downsample` is set, and ReLU follows the residual add.
    """

    def __init__(self,
                 inplanes,
                 planes,
                 stride,
                 groups,
                 neck_ratio=1,
                 has_se=True,
                 downsample=False):
        super(XBlock, self).__init__()
        self.has_se = has_se
        self.downsample = downsample

        if self.downsample:
            # 1x1 projection so the shortcut matches the main path's shape.
            self.downsample_layer = ConvBnActBlock(inplanes,
                                                   planes,
                                                   kernel_size=1,
                                                   stride=stride,
                                                   padding=0,
                                                   groups=1,
                                                   has_bn=True,
                                                   has_act=False)
        self.conv1 = ConvBnActBlock(inplanes,
                                    planes,
                                    kernel_size=1,
                                    stride=1,
                                    padding=0,
                                    groups=1,
                                    has_bn=True,
                                    has_act=True)
        self.conv2 = ConvBnActBlock(planes,
                                    planes // neck_ratio,
                                    kernel_size=3,
                                    stride=stride,
                                    padding=1,
                                    groups=groups,
                                    has_bn=True,
                                    has_act=True)
        # No activation on conv3: ReLU is applied after the residual add.
        self.conv3 = ConvBnActBlock(planes // neck_ratio,
                                    planes,
                                    kernel_size=1,
                                    stride=1,
                                    padding=0,
                                    groups=1,
                                    has_bn=True,
                                    has_act=False)
        self.relu = nn.ReLU(inplace=True)

        if self.has_se:
            self.se_block = SeBlock(planes)

    def forward(self, inputs):
        x = self.conv3(self.conv2(self.conv1(inputs)))
        if self.has_se:
            x = self.se_block(x)

        shortcut = self.downsample_layer(inputs) if self.downsample else inputs
        x += shortcut
        return self.relu(x)


class RegNet(nn.Module):
    """RegNetY image classifier.

    The four stages' widths, depths and group counts are derived from the
    design-space parameters in RegnetY_types_config via get_regnet_config;
    each stage halves the spatial resolution (stride 2 on its first block).
    """

    def __init__(self, regnet_type, num_classes=1000):
        super(RegNet, self).__init__()
        (stem_width, has_se, per_stage_width, per_stage_depth,
         per_stage_groups) = get_regnet_config(
             RegnetY_types_config[regnet_type], q=8)

        self.stem_width = stem_width
        self.has_se = has_se
        self.per_stage_width = per_stage_width
        self.per_stage_depth = per_stage_depth
        self.per_stage_groups = per_stage_groups

        # Every stage needs a matching width/depth/groups triple.
        assert (len(self.per_stage_width) == len(self.per_stage_depth) ==
                len(self.per_stage_groups))

        # Stem: 3x3 stride-2 conv from RGB to stem_width channels.
        self.conv1 = ConvBnActBlock(3,
                                    self.stem_width,
                                    kernel_size=3,
                                    stride=2,
                                    padding=1,
                                    groups=1,
                                    has_bn=True,
                                    has_act=True)

        self.layer1 = self.make_layer(self.stem_width,
                                      self.per_stage_width[0],
                                      stride=2,
                                      block_num=self.per_stage_depth[0],
                                      group_num=self.per_stage_groups[0],
                                      has_se=self.has_se)
        self.layer2 = self.make_layer(self.per_stage_width[0],
                                      self.per_stage_width[1],
                                      stride=2,
                                      block_num=self.per_stage_depth[1],
                                      group_num=self.per_stage_groups[1],
                                      has_se=self.has_se)
        self.layer3 = self.make_layer(self.per_stage_width[1],
                                      self.per_stage_width[2],
                                      stride=2,
                                      block_num=self.per_stage_depth[2],
                                      group_num=self.per_stage_groups[2],
                                      has_se=self.has_se)
        self.layer4 = self.make_layer(self.per_stage_width[2],
                                      self.per_stage_width[3],
                                      stride=2,
                                      block_num=self.per_stage_depth[3],
                                      group_num=self.per_stage_groups[3],
                                      has_se=self.has_se)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(self.per_stage_width[3], num_classes)

        # He init for convs; BN/GN scales to 1 and shifts to 0.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def make_layer(self, inplanes, planes, stride, block_num, group_num,
                   has_se):
        """Stack block_num XBlocks; only the first block may downsample."""
        blocks = []
        for idx in range(block_num):
            if idx == 0:
                # First block: applies the stage stride and, if the shape
                # changes, a projection shortcut.
                blocks.append(
                    XBlock(inplanes,
                           planes,
                           stride=stride,
                           groups=group_num,
                           has_se=has_se,
                           downsample=(stride != 1 or inplanes != planes)))
            else:
                blocks.append(
                    XBlock(planes,
                           planes,
                           stride=1,
                           groups=group_num,
                           has_se=has_se,
                           downsample=False))

        return nn.Sequential(*blocks)

    def forward(self, x):
        x = self.conv1(x)
        x = self.layer4(self.layer3(self.layer2(self.layer1(x))))
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)


def _regnet(arch, pretrained, progress, **kwargs):
    """Build a RegNet model, optionally loading local pretrained weights.

    Args:
        arch: key into RegnetY_types_config / model_urls.
        pretrained: if True, load a state_dict from model_urls[arch].
        progress: unused; kept for interface parity with torchvision
            factory functions.

    Raises:
        ValueError: if pretrained weights were requested for a variant
            whose model_urls entry is the 'empty' placeholder.
    """
    model = RegNet(arch, **kwargs)
    # only load state_dict()
    if pretrained:
        weight_path = model_urls[arch]
        if weight_path == 'empty':
            # Previously torch.load('empty') failed with a confusing
            # file-not-found error; fail fast with a clear message instead.
            raise ValueError(
                "no pretrained weights available for '{}'".format(arch))
        model.load_state_dict(
            torch.load(weight_path, map_location=torch.device('cpu')))

    return model


# Factory functions, one per RegNetY variant. Each forwards to _regnet;
# pass pretrained=True to load the local weights listed in model_urls
# (variants whose entry is 'empty' have no trained weights).
def RegNetY_200MF(pretrained=False, progress=True, **kwargs):
    return _regnet('RegNetY_200MF', pretrained, progress, **kwargs)


def RegNetY_400MF(pretrained=False, progress=True, **kwargs):
    return _regnet('RegNetY_400MF', pretrained, progress, **kwargs)


def RegNetY_600MF(pretrained=False, progress=True, **kwargs):
    return _regnet('RegNetY_600MF', pretrained, progress, **kwargs)


def RegNetY_800MF(pretrained=False, progress=True, **kwargs):
    return _regnet('RegNetY_800MF', pretrained, progress, **kwargs)


def RegNetY_1_6GF(pretrained=False, progress=True, **kwargs):
    return _regnet('RegNetY_1_6GF', pretrained, progress, **kwargs)


def RegNetY_3_2GF(pretrained=False, progress=True, **kwargs):
    return _regnet('RegNetY_3_2GF', pretrained, progress, **kwargs)


def RegNetY_4_0GF(pretrained=False, progress=True, **kwargs):
    return _regnet('RegNetY_4_0GF', pretrained, progress, **kwargs)


def RegNetY_6_4GF(pretrained=False, progress=True, **kwargs):
    return _regnet('RegNetY_6_4GF', pretrained, progress, **kwargs)


def RegNetY_8_0GF(pretrained=False, progress=True, **kwargs):
    return _regnet('RegNetY_8_0GF', pretrained, progress, **kwargs)


def RegNetY_12GF(pretrained=False, progress=True, **kwargs):
    return _regnet('RegNetY_12GF', pretrained, progress, **kwargs)


def RegNetY_16GF(pretrained=False, progress=True, **kwargs):
    return _regnet('RegNetY_16GF', pretrained, progress, **kwargs)


def RegNetY_32GF(pretrained=False, progress=True, **kwargs):
    return _regnet('RegNetY_32GF', pretrained, progress, **kwargs)

RegNet的訓練

對於RegNet,我們基本上遵循官方的訓練超參數:
config.py文件如下:

import torchvision.transforms as transforms
import torchvision.datasets as datasets


class Config(object):
    """Static training configuration; values are used as argparse defaults.

    NOTE(review): train_dataset and val_dataset are instantiated here, at
    class-definition time — importing this module touches the dataset
    paths immediately.
    """
    log = './log'  # Path to save log
    checkpoint_path = './checkpoints'  # Path to store checkpoint model
    resume = './checkpoints/latest.pth'  # load checkpoint model
    evaluate = None  # evaluate model path
    train_dataset_path = '/data/ILSVRC2012/train/'
    val_dataset_path = '/data/ILSVRC2012/val/'

    network = "RegNetY_800MF"
    pretrained = False
    num_classes = 1000
    seed = 0
    input_image_size = 224
    # Validation resize ratio: resize to 256 then center-crop 224.
    scale = 256 / 224

    # Training augmentation: random resized crop, horizontal flip, color
    # jitter, then ImageNet mean/std normalization.
    train_dataset = datasets.ImageFolder(
        train_dataset_path,
        transforms.Compose([
            transforms.RandomResizedCrop(input_image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4,
                                   contrast=0.4,
                                   saturation=0.4,
                                   hue=0),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]))
    # Validation: deterministic resize + center crop with the same
    # normalization as training.
    val_dataset = datasets.ImageFolder(
        val_dataset_path,
        transforms.Compose([
            transforms.Resize(int(input_image_size * scale)),
            transforms.CenterCrop(input_image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]))

    epochs = 100
    warm_up_epochs = 5
    batch_size = 256
    accumulation_steps = 1  # gradient accumulation steps (1 = step every batch)
    lr = 0.2
    weight_decay = 5e-5
    momentum = 0.9
    num_workers = 8
    print_interval = 100  # log every N training iterations
    apex = False  # mixed precision via NVIDIA apex

train.py文件如下:

import sys
import os
import argparse
import math
import random
import time
import warnings

# Make the project root (three directory levels up from this file)
# importable so `public.imagenet` and `config` resolve when the script is
# run directly; silence library warnings for cleaner training logs.
BASE_DIR = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(BASE_DIR)
warnings.filterwarnings('ignore')

from apex import amp
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
from thop import profile
from thop import clever_format
from torch.utils.data import DataLoader
from config import Config
from public.imagenet import models
from public.imagenet.utils import DataPrefetcher, get_logger, AverageMeter, accuracy


def str2bool(v):
    """Parse a command-line boolean.

    argparse's `type=bool` treats ANY non-empty string as True (so
    `--apex False` would enable apex); this converter handles the common
    textual spellings explicitly.
    """
    if isinstance(v, bool):
        return v
    return v.lower() in ('true', '1', 'yes', 'y')


def parse_args():
    """Build the training CLI; defaults come from Config."""
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('--network',
                        type=str,
                        default=Config.network,
                        help='name of network')
    parser.add_argument('--lr',
                        type=float,
                        default=Config.lr,
                        help='learning rate')
    parser.add_argument('--momentum',
                        type=float,
                        default=Config.momentum,
                        help='momentum')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=Config.weight_decay,
                        help='weight decay')
    parser.add_argument('--epochs',
                        type=int,
                        default=Config.epochs,
                        help='num of training epochs')
    parser.add_argument('--warm_up_epochs',
                        type=int,
                        default=Config.warm_up_epochs,
                        help='num of warm up epochs')
    parser.add_argument('--batch_size',
                        type=int,
                        default=Config.batch_size,
                        help='batch size')
    parser.add_argument('--accumulation_steps',
                        type=int,
                        default=Config.accumulation_steps,
                        help='gradient accumulation steps')
    # Fixed: was type=bool, which parses '--pretrained False' as True.
    parser.add_argument('--pretrained',
                        type=str2bool,
                        default=Config.pretrained,
                        help='load pretrained model params or not')
    parser.add_argument('--num_classes',
                        type=int,
                        default=Config.num_classes,
                        help='model classification num')
    parser.add_argument('--input_image_size',
                        type=int,
                        default=Config.input_image_size,
                        help='input image size')
    parser.add_argument('--num_workers',
                        type=int,
                        default=Config.num_workers,
                        help='number of worker to load data')
    parser.add_argument('--resume',
                        type=str,
                        default=Config.resume,
                        help='put the path to resuming file if needed')
    parser.add_argument('--checkpoints',
                        type=str,
                        default=Config.checkpoint_path,
                        help='path for saving trained models')
    parser.add_argument('--log',
                        type=str,
                        default=Config.log,
                        help='path to save log')
    parser.add_argument('--evaluate',
                        type=str,
                        default=Config.evaluate,
                        help='path for evaluate model')
    parser.add_argument('--seed', type=int, default=Config.seed, help='seed')
    # Fixed: was type=bool, but this value is used as a modulo divisor
    # (iter_index % args.print_interval) and must be an int.
    parser.add_argument('--print_interval',
                        type=int,
                        default=Config.print_interval,
                        help='print interval')
    # Fixed: was type=bool (see str2bool above).
    parser.add_argument('--apex',
                        type=str2bool,
                        default=Config.apex,
                        help='use apex or not')

    return parser.parse_args()


def train(train_loader, model, criterion, optimizer, scheduler, epoch, logger,
          args):
    """Run one training epoch; returns (top1_avg, top5_avg, loss_avg).

    Uses gradient accumulation: the optimizer steps once every
    args.accumulation_steps iterations. The LR scheduler steps once per
    epoch, at the end.
    """
    top1 = AverageMeter()
    top5 = AverageMeter()
    losses = AverageMeter()

    # switch to train mode
    model.train()

    # Number of full-size batches per epoch (used for logging only).
    iters = len(train_loader.dataset) // args.batch_size
    prefetcher = DataPrefetcher(train_loader)
    inputs, labels = prefetcher.next()
    iter_index = 1

    # The prefetcher signals end-of-epoch by returning None.
    while inputs is not None:
        # presumably DataPrefetcher already moved tensors to GPU and this
        # .cuda() is a no-op — TODO confirm against DataPrefetcher.
        inputs, labels = inputs.cuda(), labels.cuda()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # Scale the loss so accumulated gradients match one larger batch.
        loss = loss / args.accumulation_steps

        if args.apex:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        if iter_index % args.accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        # measure accuracy and record loss
        acc1, acc5 = accuracy(outputs, labels, topk=(1, 5))
        top1.update(acc1.item(), inputs.size(0))
        top5.update(acc5.item(), inputs.size(0))
        # NOTE(review): the recorded loss is the accumulation-scaled value.
        losses.update(loss.item(), inputs.size(0))

        inputs, labels = prefetcher.next()

        if iter_index % args.print_interval == 0:
            logger.info(
                f"train: epoch {epoch:0>3d}, iter [{iter_index:0>4d}, {iters:0>4d}], lr: {scheduler.get_lr()[0]:.6f}, top1 acc: {acc1.item():.2f}%, top5 acc: {acc5.item():.2f}%, loss_total: {loss.item():.2f}"
            )

        iter_index += 1

    # One scheduler step per epoch (warmup + cosine schedule in main()).
    scheduler.step()

    return top1.avg, top5.avg, losses.avg


def validate(val_loader, model, args):
    """Evaluate the model on val_loader; returns (top1, top5, throughput).

    throughput is samples/second estimated from the mean per-batch wall
    time and the size of the LAST batch. NOTE(review): the final batch may
    be smaller than batch_size, which skews this estimate; also raises
    NameError if val_loader is empty (inputs would be unbound). `args` is
    currently unused.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for inputs, labels in val_loader:
            data_time.update(time.time() - end)
            inputs, labels = inputs.cuda(), labels.cuda()
            outputs = model(inputs)
            acc1, acc5 = accuracy(outputs, labels, topk=(1, 5))
            top1.update(acc1.item(), inputs.size(0))
            top5.update(acc5.item(), inputs.size(0))
            # batch_time includes data-loading wait plus forward pass.
            batch_time.update(time.time() - end)
            end = time.time()

    throughput = 1.0 / (batch_time.avg / inputs.size(0))

    return top1.avg, top5.avg, throughput


def main(logger, args):
    """Training driver: data, model, optimizer, resume/evaluate, loop.

    Requires at least one CUDA device. Three modes:
      * evaluation only, when args.evaluate points to a checkpoint;
      * resume training, when args.resume exists;
      * train from scratch otherwise.
    Saves 'latest.pth' every epoch and a final weights-only file after the
    last epoch.
    """
    if not torch.cuda.is_available():
        raise Exception("need gpu to train network!")

    # Seed Python/PyTorch RNGs for reproducibility when a seed is given.
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True

    gpus = torch.cuda.device_count()
    logger.info(f'use {gpus} gpus')
    logger.info(f"args: {args}")

    cudnn.benchmark = True
    cudnn.enabled = True
    start_time = time.time()

    # dataset and dataloader
    logger.info('start loading data')
    train_loader = DataLoader(Config.train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              pin_memory=True,
                              num_workers=args.num_workers)
    val_loader = DataLoader(Config.val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            pin_memory=True,
                            num_workers=args.num_workers)
    logger.info('finish loading data')

    logger.info(f"creating model '{args.network}'")
    model = models.__dict__[args.network](**{
        "pretrained": args.pretrained,
        "num_classes": args.num_classes,
    })

    # Report FLOPs/params on a single CPU image before moving to GPU.
    flops_input = torch.randn(1, 3, args.input_image_size,
                              args.input_image_size)
    flops, params = profile(model, inputs=(flops_input, ))
    flops, params = clever_format([flops, params], "%.3f")
    logger.info(f"model: '{args.network}', flops: {flops}, params: {params}")

    for name, param in model.named_parameters():
        logger.info(f"{name},{param.requires_grad}")

    model = model.cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # warm_up_with_cosine_lr: linear ramp over warm_up_epochs, then cosine
    # decay over the remaining epochs.
    warm_up_with_cosine_lr = lambda epoch: epoch / args.warm_up_epochs if epoch <= args.warm_up_epochs else 0.5 * (
        math.cos((epoch - args.warm_up_epochs) /
                 (args.epochs - args.warm_up_epochs) * math.pi) + 1)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=warm_up_with_cosine_lr)

    if args.apex:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    model = nn.DataParallel(model)

    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            # Fixed: this message previously interpolated args.resume even
            # though the check is on args.evaluate.
            raise Exception(
                f"{args.evaluate} is not a file, please check it again")
        logger.info('start only evaluating')
        logger.info(f"start resuming model from {args.evaluate}")
        checkpoint = torch.load(args.evaluate,
                                map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model_state_dict'])
        acc1, acc5, throughput = validate(val_loader, model, args)
        logger.info(
            f"epoch {checkpoint['epoch']:0>3d}, top1 acc: {acc1:.2f}%, top5 acc: {acc5:.2f}%, throughput: {throughput:.2f}sample/s"
        )

        return

    start_epoch = 1
    # resume training
    if os.path.exists(args.resume):
        logger.info(f"start resuming model from {args.resume}")
        checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
        start_epoch += checkpoint['epoch']
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        logger.info(
            f"finish resuming model from {args.resume}, epoch {checkpoint['epoch']}, "
            f"loss: {checkpoint['loss']:3f}, lr: {checkpoint['lr']:.6f}, "
            f"top1_acc: {checkpoint['acc1']}%")

    if not os.path.exists(args.checkpoints):
        os.makedirs(args.checkpoints)

    logger.info('start training')
    for epoch in range(start_epoch, args.epochs + 1):
        acc1, acc5, losses = train(train_loader, model, criterion, optimizer,
                                   scheduler, epoch, logger, args)
        logger.info(
            f"train: epoch {epoch:0>3d}, top1 acc: {acc1:.2f}%, top5 acc: {acc5:.2f}%, losses: {losses:.2f}"
        )

        acc1, acc5, throughput = validate(val_loader, model, args)
        logger.info(
            f"val: epoch {epoch:0>3d}, top1 acc: {acc1:.2f}%, top5 acc: {acc5:.2f}%, throughput: {throughput:.2f}sample/s"
        )

        # remember best prec@1 and save checkpoint
        # NOTE(review): scheduler.get_lr() is deprecated in newer PyTorch in
        # favor of get_last_lr() — confirm against the pinned torch version.
        torch.save(
            {
                'epoch': epoch,
                'acc1': acc1,
                'loss': losses,
                'lr': scheduler.get_lr()[0],
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
            }, os.path.join(args.checkpoints, 'latest.pth'))
        if epoch == args.epochs:
            # Final weights only (unwrap DataParallel via model.module).
            torch.save(
                model.module.state_dict(),
                os.path.join(
                    args.checkpoints,
                    "{}-epoch{}-acc{}.pth".format(args.network, epoch, acc1)))

    training_time = (time.time() - start_time) / 3600
    logger.info(
        f"finish training, total training time: {training_time:.2f} hours")


if __name__ == '__main__':
    # Parse CLI args (defaults come from Config), set up logging, then train.
    args = parse_args()
    logger = get_logger(__name__, args.log)
    main(logger, args)

對RegNet訓練時,我們總共訓練100個epoch,在前5個epoch使用warm up,並且使用cosine learning rate衰減。

RegNet訓練結果

| Network | Top-1 error |
| --- | --- |
| RegNetY_200MF | 29.904 |
| RegNetY_400MF | 26.210 |
| RegNetY_600MF | 25.276 |
| RegNetY_800MF | 24.006 |
| RegNetY_1.6GF | 22.692 |
| RegNetY_3.2GF | 21.092 |
| RegNetY_4.0GF | 21.864 |
| RegNetY_6.4GF | 21.230 |
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章