RegNet網絡結構
根據論文和官方開源代碼,我用PyTorch重新構建了RegNet網絡(論文:https://arxiv.org/pdf/2003.13678.pdf 。開源代碼:https://github.com/facebookresearch/pycls)。
網絡結構如下:
"""
Designing Network Design Spaces
https://arxiv.org/pdf/2003.13678.pdf
"""
import numpy as np
import torch
import torch.nn as nn
# Base directory holding the locally stored pretrained checkpoints.
root_path = '/home/zhugechaoran/Downloads/pytorch_imagenet_model/public/pretrained_models'
# Public API of this module: one factory function per RegNetY variant.
__all__ = [
    'RegNetY_200MF',
    'RegNetY_400MF',
    'RegNetY_600MF',
    'RegNetY_800MF',
    'RegNetY_1_6GF',
    'RegNetY_3_2GF',
    'RegNetY_4_0GF',
    'RegNetY_6_4GF',
    'RegNetY_8_0GF',
    'RegNetY_12GF',
    'RegNetY_16GF',
    'RegNetY_32GF',
]
# Local checkpoint path per variant; file names encode training epochs and
# final top-1 accuracy. 'empty' marks variants with no trained checkpoint yet.
model_urls = {
    'RegNetY_200MF':
    '{}/regnet/RegNetY_200MF-epoch100-acc70.096.pth'.format(root_path),
    'RegNetY_400MF':
    '{}/regnet/RegNetY_400MF-epoch100-acc73.79.pth'.format(root_path),
    'RegNetY_600MF':
    '{}/regnet/RegNetY_600MF-epoch100-acc74.724.pth'.format(root_path),
    'RegNetY_800MF':
    '{}/regnet/RegNetY_800MF-epoch100-acc75.994.pth'.format(root_path),
    'RegNetY_1_6GF':
    '{}/regnet/RegNetY_1_6GF-epoch100-acc77.308.pth'.format(root_path),
    'RegNetY_3_2GF':
    '{}/regnet/RegNetY_3_2GF-epoch100-acc78.908.pth'.format(root_path),
    'RegNetY_4_0GF':
    '{}/regnet/RegNetY_4_0GF-epoch100-acc78.136.pth'.format(root_path),
    'RegNetY_6_4GF':
    '{}/regnet/RegNetY_6_4GF-epoch100-acc78.77.pth'.format(root_path),
    'RegNetY_8_0GF':
    'empty',
    'RegNetY_12GF':
    'empty',
    'RegNetY_16GF':
    'empty',
    'RegNetY_32GF':
    'empty',
}
# Per-variant RegNetY parameterization, consumed by get_regnet_config():
#   stem_width: output channels of the stem convolution
#   w_a, w_0, w_m: slope, initial width and width multiplier of the
#       quantized linear width rule from the paper
#   groups: group width for the grouped 3x3 convolutions
#   depth: total number of residual blocks
#   has_se: whether blocks use Squeeze-and-Excitation
RegnetY_types_config = {
    'RegNetY_200MF': {
        "stem_width": 32,
        "w_a": 36.44,
        "w_0": 24,
        "w_m": 2.49,
        "groups": 8,
        "depth": 13,
        "has_se": True,
    },
    'RegNetY_400MF': {
        "stem_width": 32,
        "w_a": 27.89,
        "w_0": 48,
        "w_m": 2.09,
        "groups": 8,
        "depth": 16,
        "has_se": True,
    },
    'RegNetY_600MF': {
        "stem_width": 32,
        "w_a": 32.54,
        "w_0": 48,
        "w_m": 2.32,
        "groups": 16,
        "depth": 15,
        "has_se": True,
    },
    'RegNetY_800MF': {
        "stem_width": 32,
        "w_a": 38.84,
        "w_0": 56,
        "w_m": 2.4,
        "groups": 16,
        "depth": 14,
        "has_se": True,
    },
    'RegNetY_1_6GF': {
        "stem_width": 32,
        "w_a": 20.71,
        "w_0": 48,
        "w_m": 2.65,
        "groups": 24,
        "depth": 27,
        "has_se": True,
    },
    'RegNetY_3_2GF': {
        "stem_width": 32,
        "w_a": 42.63,
        "w_0": 80,
        "w_m": 2.66,
        "groups": 24,
        "depth": 21,
        "has_se": True,
    },
    'RegNetY_4_0GF': {
        "stem_width": 32,
        "w_a": 31.41,
        "w_0": 96,
        "w_m": 2.24,
        "groups": 64,
        "depth": 22,
        "has_se": True,
    },
    'RegNetY_6_4GF': {
        "stem_width": 32,
        "w_a": 33.22,
        "w_0": 112,
        "w_m": 2.27,
        "groups": 72,
        "depth": 25,
        "has_se": True,
    },
    'RegNetY_8_0GF': {
        "stem_width": 32,
        "w_a": 76.82,
        "w_0": 192,
        "w_m": 2.19,
        "groups": 56,
        "depth": 17,
        "has_se": True,
    },
    'RegNetY_12GF': {
        "stem_width": 32,
        "w_a": 73.36,
        "w_0": 168,
        "w_m": 2.37,
        "groups": 112,
        "depth": 19,
        "has_se": True,
    },
    'RegNetY_16GF': {
        "stem_width": 32,
        "w_a": 106.23,
        "w_0": 200,
        "w_m": 2.48,
        "groups": 112,
        "depth": 18,
        "has_se": True,
    },
    'RegNetY_32GF': {
        "stem_width": 32,
        "w_a": 115.89,
        "w_0": 232,
        "w_m": 2.53,
        "groups": 232,
        "depth": 20,
        "has_se": True,
    },
}
def get_regnet_config(regnet_type, q=8):
stem_width, has_se = regnet_type["stem_width"], regnet_type["has_se"]
w_a, w_0, w_m, depth, groups = regnet_type["w_a"], regnet_type[
"w_0"], regnet_type["w_m"], regnet_type["depth"], regnet_type["groups"]
assert w_a >= 0 and w_0 > 0 and w_m > 1 and w_0 % q == 0
ks = np.round(np.log((np.arange(depth) * w_a + w_0) / w_0) / np.log(w_m))
per_stage_width = w_0 * np.power(w_m, ks)
per_stage_width = (np.round(np.divide(per_stage_width, q)) *
q).astype(int).tolist()
ts_temp = zip(per_stage_width + [0], [0] + per_stage_width,
per_stage_width + [0], [0] + per_stage_width)
ts = [w != wp or r != rp for w, wp, r, rp in ts_temp]
per_stage_depth = np.diff([d for d, t in zip(range(len(ts)), ts)
if t]).tolist()
per_stage_width = np.unique(per_stage_width).tolist()
per_stage_groups = [groups for _ in range(len(per_stage_width))]
per_stage_groups = [
min(per_g, per_w)
for per_g, per_w in zip(per_stage_groups, per_stage_width)
]
per_stage_width = [
int(round(per_w / per_g) * per_g)
for per_w, per_g in zip(per_stage_width, per_stage_groups)
]
return stem_width, has_se, per_stage_width, per_stage_depth, per_stage_groups
class ConvBnActBlock(nn.Module):
    """Conv2d (bias-free) optionally followed by BatchNorm2d and ReLU.

    Args:
        inplanes: input channels.
        planes: output channels.
        kernel_size, stride, padding, groups: forwarded to nn.Conv2d.
        has_bn: append a BatchNorm2d after the convolution.
        has_act: append an inplace ReLU at the end.
    """

    def __init__(self,
                 inplanes,
                 planes,
                 kernel_size,
                 stride,
                 padding,
                 groups=1,
                 has_bn=True,
                 has_act=True):
        super().__init__()
        self.has_bn = has_bn
        self.has_act = has_act
        # bias is omitted: BatchNorm (when present) makes it redundant
        self.conv = nn.Conv2d(inplanes, planes, kernel_size, stride=stride,
                              padding=padding, groups=groups, bias=False)
        if has_bn:
            self.bn = nn.BatchNorm2d(planes)
        if has_act:
            self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out) if self.has_bn else out
        return self.act(out) if self.has_act else out
class SeBlock(nn.Module):
    """Squeeze-and-Excitation channel gate.

    Global-average-pools the input, squeezes channels by `reduction`
    through a 1x1 conv + ReLU, expands back with a second 1x1 conv, and
    rescales the input channel-wise by the sigmoid of the result.
    """

    def __init__(self, inplanes, reduction=4):
        super(SeBlock, self).__init__()
        # at least one hidden channel even for tiny inputs
        hidden_planes = max(1, int(inplanes // reduction))
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.conv1 = nn.Conv2d(inplanes, hidden_planes, kernel_size=1,
                               stride=1, padding=0, bias=True)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(hidden_planes, inplanes, kernel_size=1,
                               stride=1, padding=0, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        gate = self.avgpool(inputs)
        gate = self.sigmoid(self.conv2(self.relu(self.conv1(gate))))
        return inputs * gate
class XBlock(nn.Module):
    """RegNetY residual bottleneck block.

    Main path: 1x1 conv -> grouped 3x3 conv (carries the stride) -> 1x1
    conv, with an optional SE gate applied after the last conv. The
    shortcut is the identity, or a strided 1x1 projection when
    `downsample` is set; the sum is passed through a final ReLU.
    """

    def __init__(self,
                 inplanes,
                 planes,
                 stride,
                 groups,
                 neck_ratio=1,
                 has_se=True,
                 downsample=False):
        super(XBlock, self).__init__()
        self.has_se = has_se
        self.downsample = downsample
        neck_planes = planes // neck_ratio
        if downsample:
            # 1x1 projection so the shortcut matches the main path's
            # shape (channels and spatial stride)
            self.downsample_layer = ConvBnActBlock(inplanes, planes,
                                                   kernel_size=1,
                                                   stride=stride, padding=0,
                                                   groups=1, has_bn=True,
                                                   has_act=False)
        self.conv1 = ConvBnActBlock(inplanes, planes, kernel_size=1,
                                    stride=1, padding=0, groups=1,
                                    has_bn=True, has_act=True)
        self.conv2 = ConvBnActBlock(planes, neck_planes, kernel_size=3,
                                    stride=stride, padding=1, groups=groups,
                                    has_bn=True, has_act=True)
        # no activation here: the ReLU comes after the residual addition
        self.conv3 = ConvBnActBlock(neck_planes, planes, kernel_size=1,
                                    stride=1, padding=0, groups=1,
                                    has_bn=True, has_act=False)
        self.relu = nn.ReLU(inplace=True)
        if has_se:
            self.se_block = SeBlock(planes)

    def forward(self, inputs):
        out = self.conv3(self.conv2(self.conv1(inputs)))
        if self.has_se:
            out = self.se_block(out)
        shortcut = self.downsample_layer(inputs) if self.downsample else inputs
        return self.relu(out + shortcut)
class RegNet(nn.Module):
    """RegNetY classification network.

    A strided 3x3 stem followed by four stages of XBlocks (each stage
    halves the spatial resolution), global average pooling and a linear
    classifier. Stage widths/depths/groups are derived from the named
    configuration in RegnetY_types_config.
    """

    def __init__(self, regnet_type, num_classes=1000):
        super(RegNet, self).__init__()
        (self.stem_width, self.has_se, self.per_stage_width,
         self.per_stage_depth, self.per_stage_groups) = get_regnet_config(
             RegnetY_types_config[regnet_type], q=8)
        assert len(self.per_stage_width) == len(self.per_stage_depth)
        assert len(self.per_stage_depth) == len(self.per_stage_groups)
        self.conv1 = ConvBnActBlock(3, self.stem_width, kernel_size=3,
                                    stride=2, padding=1, groups=1,
                                    has_bn=True, has_act=True)
        # Build layer1..layer4; each stage downsamples by 2 at its first
        # block and feeds the next stage its output width.
        in_width = self.stem_width
        for stage_index, (width, depth, group) in enumerate(
                zip(self.per_stage_width, self.per_stage_depth,
                    self.per_stage_groups),
                start=1):
            setattr(
                self, 'layer{}'.format(stage_index),
                self.make_layer(in_width, width, stride=2, block_num=depth,
                                group_num=group, has_se=self.has_se))
            in_width = width
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(self.per_stage_width[-1], num_classes)

        # Standard He initialization for convs; BN starts as identity.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def make_layer(self, inplanes, planes, stride, block_num, group_num,
                   has_se):
        """Stack `block_num` XBlocks; only the first one may downsample."""
        blocks = []
        for block_index in range(block_num):
            if block_index == 0:
                needs_projection = stride != 1 or inplanes != planes
                blocks.append(
                    XBlock(inplanes, planes, stride=stride, groups=group_num,
                           has_se=has_se, downsample=needs_projection))
            else:
                blocks.append(
                    XBlock(planes, planes, stride=1, groups=group_num,
                           has_se=has_se, downsample=False))
        return nn.Sequential(*blocks)

    def forward(self, x):
        x = self.conv1(x)
        for stage_index in range(1, 5):
            x = getattr(self, 'layer{}'.format(stage_index))(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)
def _regnet(arch, pretrained, progress, **kwargs):
    """Build the RegNet variant named `arch`, optionally loading weights.

    Only a bare state_dict is loaded (from the local path in
    model_urls[arch]); `progress` is accepted for API symmetry with
    torchvision but unused.
    """
    model = RegNet(arch, **kwargs)
    if pretrained:
        state_dict = torch.load(model_urls[arch],
                                map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)
    return model
# Public factory functions, one per RegNetY variant listed in __all__.
def RegNetY_200MF(pretrained=False, progress=True, **kwargs):
    """RegNetY-200MF (~0.2 GFLOPs)."""
    return _regnet('RegNetY_200MF', pretrained, progress, **kwargs)
def RegNetY_400MF(pretrained=False, progress=True, **kwargs):
    """RegNetY-400MF (~0.4 GFLOPs)."""
    return _regnet('RegNetY_400MF', pretrained, progress, **kwargs)
def RegNetY_600MF(pretrained=False, progress=True, **kwargs):
    """RegNetY-600MF (~0.6 GFLOPs)."""
    return _regnet('RegNetY_600MF', pretrained, progress, **kwargs)
def RegNetY_800MF(pretrained=False, progress=True, **kwargs):
    """RegNetY-800MF (~0.8 GFLOPs)."""
    return _regnet('RegNetY_800MF', pretrained, progress, **kwargs)
def RegNetY_1_6GF(pretrained=False, progress=True, **kwargs):
    """RegNetY-1.6GF."""
    return _regnet('RegNetY_1_6GF', pretrained, progress, **kwargs)
def RegNetY_3_2GF(pretrained=False, progress=True, **kwargs):
    """RegNetY-3.2GF."""
    return _regnet('RegNetY_3_2GF', pretrained, progress, **kwargs)
def RegNetY_4_0GF(pretrained=False, progress=True, **kwargs):
    """RegNetY-4.0GF."""
    return _regnet('RegNetY_4_0GF', pretrained, progress, **kwargs)
def RegNetY_6_4GF(pretrained=False, progress=True, **kwargs):
    """RegNetY-6.4GF."""
    return _regnet('RegNetY_6_4GF', pretrained, progress, **kwargs)
def RegNetY_8_0GF(pretrained=False, progress=True, **kwargs):
    """RegNetY-8.0GF (no pretrained checkpoint available)."""
    return _regnet('RegNetY_8_0GF', pretrained, progress, **kwargs)
def RegNetY_12GF(pretrained=False, progress=True, **kwargs):
    """RegNetY-12GF (no pretrained checkpoint available)."""
    return _regnet('RegNetY_12GF', pretrained, progress, **kwargs)
def RegNetY_16GF(pretrained=False, progress=True, **kwargs):
    """RegNetY-16GF (no pretrained checkpoint available)."""
    return _regnet('RegNetY_16GF', pretrained, progress, **kwargs)
def RegNetY_32GF(pretrained=False, progress=True, **kwargs):
    """RegNetY-32GF (no pretrained checkpoint available)."""
    return _regnet('RegNetY_32GF', pretrained, progress, **kwargs)
RegNet的訓練
對於RegNet,我們基本上遵循官方的訓練超參數:
config.py文件如下:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
class Config(object):
    """Static training configuration: paths, hyperparameters and datasets."""
    log = './log'  # Path to save log
    checkpoint_path = './checkpoints'  # Path to store checkpoint model
    resume = './checkpoints/latest.pth'  # load checkpoint model
    evaluate = None  # evaluate model path
    train_dataset_path = '/data/ILSVRC2012/train/'
    val_dataset_path = '/data/ILSVRC2012/val/'
    network = "RegNetY_800MF"
    pretrained = False
    num_classes = 1000
    seed = 0
    input_image_size = 224
    # Validation images are resized to input_image_size * scale before the
    # center crop (256 for a 224 crop).
    scale = 256 / 224
    # NOTE(review): both datasets are instantiated while this class body
    # executes, i.e. at import time of config.py — importing this module
    # requires the ImageNet paths above to exist.
    train_dataset = datasets.ImageFolder(
        train_dataset_path,
        transforms.Compose([
            transforms.RandomResizedCrop(input_image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4,
                                   contrast=0.4,
                                   saturation=0.4,
                                   hue=0),
            transforms.ToTensor(),
            # ImageNet channel statistics
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]))
    val_dataset = datasets.ImageFolder(
        val_dataset_path,
        transforms.Compose([
            transforms.Resize(int(input_image_size * scale)),
            transforms.CenterCrop(input_image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]))
    epochs = 100
    warm_up_epochs = 5
    batch_size = 256
    accumulation_steps = 1
    lr = 0.2
    weight_decay = 5e-5
    momentum = 0.9
    num_workers = 8
    print_interval = 100  # log every N training iterations
    apex = False  # mixed-precision training via NVIDIA apex
train.py文件如下:
import sys
import os
import argparse
import math
import random
import time
import warnings
BASE_DIR = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(BASE_DIR)
warnings.filterwarnings('ignore')
from apex import amp
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
from thop import profile
from thop import clever_format
from torch.utils.data import DataLoader
from config import Config
from public.imagenet import models
from public.imagenet.utils import DataPrefetcher, get_logger, AverageMeter, accuracy
def parse_args():
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('--network',
type=str,
default=Config.network,
help='name of network')
parser.add_argument('--lr',
type=float,
default=Config.lr,
help='learning rate')
parser.add_argument('--momentum',
type=float,
default=Config.momentum,
help='momentum')
parser.add_argument('--weight_decay',
type=float,
default=Config.weight_decay,
help='weight decay')
parser.add_argument('--epochs',
type=int,
default=Config.epochs,
help='num of training epochs')
parser.add_argument('--warm_up_epochs',
type=int,
default=Config.warm_up_epochs,
help='num of warm up epochs')
parser.add_argument('--batch_size',
type=int,
default=Config.batch_size,
help='batch size')
parser.add_argument('--accumulation_steps',
type=int,
default=Config.accumulation_steps,
help='gradient accumulation steps')
parser.add_argument('--pretrained',
type=bool,
default=Config.pretrained,
help='load pretrained model params or not')
parser.add_argument('--num_classes',
type=int,
default=Config.num_classes,
help='model classification num')
parser.add_argument('--input_image_size',
type=int,
default=Config.input_image_size,
help='input image size')
parser.add_argument('--num_workers',
type=int,
default=Config.num_workers,
help='number of worker to load data')
parser.add_argument('--resume',
type=str,
default=Config.resume,
help='put the path to resuming file if needed')
parser.add_argument('--checkpoints',
type=str,
default=Config.checkpoint_path,
help='path for saving trained models')
parser.add_argument('--log',
type=str,
default=Config.log,
help='path to save log')
parser.add_argument('--evaluate',
type=str,
default=Config.evaluate,
help='path for evaluate model')
parser.add_argument('--seed', type=int, default=Config.seed, help='seed')
parser.add_argument('--print_interval',
type=bool,
default=Config.print_interval,
help='print interval')
parser.add_argument('--apex',
type=bool,
default=Config.apex,
help='use apex or not')
return parser.parse_args()
def train(train_loader, model, criterion, optimizer, scheduler, epoch, logger,
          args):
    """Train `model` for one epoch over `train_loader`.

    Supports gradient accumulation (args.accumulation_steps) and apex
    mixed precision (args.apex). Steps the LR scheduler once per epoch.

    Returns:
        (top1_avg, top5_avg, loss_avg): epoch-averaged top-1/top-5
        accuracy and loss.
    """
    top1 = AverageMeter()
    top5 = AverageMeter()
    losses = AverageMeter()
    # switch to train mode
    model.train()
    # number of full batches per epoch; used only for log formatting
    iters = len(train_loader.dataset) // args.batch_size
    prefetcher = DataPrefetcher(train_loader)
    inputs, labels = prefetcher.next()
    iter_index = 1
    while inputs is not None:
        # NOTE(review): DataPrefetcher presumably already moves batches to
        # the GPU, which would make these .cuda() calls no-ops — confirm
        # against its implementation.
        inputs, labels = inputs.cuda(), labels.cuda()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # scale so accumulated gradients average over the effective batch
        loss = loss / args.accumulation_steps
        if args.apex:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        # only step/zero the optimizer every accumulation_steps iterations
        if iter_index % args.accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        # measure accuracy and record loss
        acc1, acc5 = accuracy(outputs, labels, topk=(1, 5))
        top1.update(acc1.item(), inputs.size(0))
        top5.update(acc5.item(), inputs.size(0))
        # NOTE(review): the recorded loss is the scaled value, so logged
        # losses shrink by a factor of accumulation_steps when > 1.
        losses.update(loss.item(), inputs.size(0))
        inputs, labels = prefetcher.next()
        if iter_index % args.print_interval == 0:
            logger.info(
                f"train: epoch {epoch:0>3d}, iter [{iter_index:0>4d}, {iters:0>4d}], lr: {scheduler.get_lr()[0]:.6f}, top1 acc: {acc1.item():.2f}%, top5 acc: {acc5.item():.2f}%, loss_total: {loss.item():.2f}"
            )
        iter_index += 1
    # one scheduler step per epoch (warm-up + cosine decay LambdaLR)
    scheduler.step()
    return top1.avg, top5.avg, losses.avg
def validate(val_loader, model, args):
    """Evaluate `model` on `val_loader`.

    Args:
        val_loader: DataLoader yielding (inputs, labels) batches.
        model: network to evaluate (switched to eval mode here).
        args: parsed arguments (unused, kept for call-site symmetry).

    Returns:
        (top1_avg, top5_avg, throughput): averaged accuracies and samples
        processed per second (0.0 when the loader is empty).
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # switch to evaluate mode
    model.eval()
    last_batch_size = 0
    with torch.no_grad():
        end = time.time()
        for inputs, labels in val_loader:
            data_time.update(time.time() - end)
            inputs, labels = inputs.cuda(), labels.cuda()
            outputs = model(inputs)
            acc1, acc5 = accuracy(outputs, labels, topk=(1, 5))
            top1.update(acc1.item(), inputs.size(0))
            top5.update(acc5.item(), inputs.size(0))
            last_batch_size = inputs.size(0)
            batch_time.update(time.time() - end)
            end = time.time()
    # bug fix: with an empty loader the original referenced an unbound
    # `inputs` (and could divide by zero); guard both cases.
    if last_batch_size == 0 or batch_time.avg == 0:
        throughput = 0.0
    else:
        throughput = 1.0 / (batch_time.avg / last_batch_size)
    return top1.avg, top5.avg, throughput
def main(logger, args):
    """Run ImageNet training (or evaluation-only, when args.evaluate set).

    Builds loaders, model, SGD optimizer and a warm-up + cosine LambdaLR
    scheduler, optionally resumes from args.resume, then trains for
    args.epochs epochs, validating and checkpointing after each epoch.

    Args:
        logger: configured logging.Logger.
        args: namespace from parse_args().

    Raises:
        Exception: if no GPU is available or args.evaluate is not a file.
    """
    if not torch.cuda.is_available():
        raise Exception("need gpu to train network!")

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True

    gpus = torch.cuda.device_count()
    logger.info(f'use {gpus} gpus')
    logger.info(f"args: {args}")

    cudnn.benchmark = True
    cudnn.enabled = True
    start_time = time.time()

    # dataset and dataloader
    logger.info('start loading data')
    train_loader = DataLoader(Config.train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              pin_memory=True,
                              num_workers=args.num_workers)
    val_loader = DataLoader(Config.val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            pin_memory=True,
                            num_workers=args.num_workers)
    logger.info('finish loading data')

    logger.info(f"creating model '{args.network}'")
    model = models.__dict__[args.network](**{
        "pretrained": args.pretrained,
        "num_classes": args.num_classes,
    })

    # report FLOPs / params from a single CPU forward pass
    flops_input = torch.randn(1, 3, args.input_image_size,
                              args.input_image_size)
    flops, params = profile(model, inputs=(flops_input, ))
    flops, params = clever_format([flops, params], "%.3f")
    logger.info(f"model: '{args.network}', flops: {flops}, params: {params}")

    for name, param in model.named_parameters():
        logger.info(f"{name},{param.requires_grad}")

    model = model.cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Linear warm-up for warm_up_epochs, then cosine decay to zero.
    # NOTE(review): the lambda returns 0 at epoch 0, so the first training
    # epoch runs at lr 0 — confirm this matches the intended schedule.
    warm_up_with_cosine_lr = lambda epoch: epoch / args.warm_up_epochs if epoch <= args.warm_up_epochs else 0.5 * (
        math.cos((epoch - args.warm_up_epochs) /
                 (args.epochs - args.warm_up_epochs) * math.pi) + 1)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=warm_up_with_cosine_lr)

    if args.apex:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    model = nn.DataParallel(model)

    # evaluation-only mode: load the given checkpoint, validate, and exit
    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            # bug fix: the message referenced args.resume while the check
            # is on args.evaluate
            raise Exception(
                f"{args.evaluate} is not a file, please check it again")
        logger.info('start only evaluating')
        logger.info(f"start resuming model from {args.evaluate}")
        checkpoint = torch.load(args.evaluate,
                                map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model_state_dict'])
        acc1, acc5, throughput = validate(val_loader, model, args)
        logger.info(
            f"epoch {checkpoint['epoch']:0>3d}, top1 acc: {acc1:.2f}%, top5 acc: {acc5:.2f}%, throughput: {throughput:.2f}sample/s"
        )
        return

    start_epoch = 1
    # resume training from the latest checkpoint if it exists
    if os.path.exists(args.resume):
        logger.info(f"start resuming model from {args.resume}")
        checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
        start_epoch += checkpoint['epoch']
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        logger.info(
            f"finish resuming model from {args.resume}, epoch {checkpoint['epoch']}, "
            f"loss: {checkpoint['loss']:3f}, lr: {checkpoint['lr']:.6f}, "
            f"top1_acc: {checkpoint['acc1']}%")

    if not os.path.exists(args.checkpoints):
        os.makedirs(args.checkpoints)

    logger.info('start training')
    for epoch in range(start_epoch, args.epochs + 1):
        acc1, acc5, losses = train(train_loader, model, criterion, optimizer,
                                   scheduler, epoch, logger, args)
        logger.info(
            f"train: epoch {epoch:0>3d}, top1 acc: {acc1:.2f}%, top5 acc: {acc5:.2f}%, losses: {losses:.2f}"
        )
        acc1, acc5, throughput = validate(val_loader, model, args)
        logger.info(
            f"val: epoch {epoch:0>3d}, top1 acc: {acc1:.2f}%, top5 acc: {acc5:.2f}%, throughput: {throughput:.2f}sample/s"
        )
        # remember best prec@1 and save checkpoint
        torch.save(
            {
                'epoch': epoch,
                'acc1': acc1,
                'loss': losses,
                'lr': scheduler.get_lr()[0],
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
            }, os.path.join(args.checkpoints, 'latest.pth'))
        if epoch == args.epochs:
            # final epoch: also save a bare state_dict (unwrapped from
            # DataParallel) named after the network and its accuracy
            torch.save(
                model.module.state_dict(),
                os.path.join(
                    args.checkpoints,
                    "{}-epoch{}-acc{}.pth".format(args.network, epoch, acc1)))

    training_time = (time.time() - start_time) / 3600
    logger.info(
        f"finish training, total training time: {training_time:.2f} hours")
if __name__ == '__main__':
    # Script entry point: parse CLI options, set up logging, then run
    # training (or evaluation-only, depending on --evaluate).
    args = parse_args()
    logger = get_logger(__name__, args.log)
    main(logger, args)
對RegNet訓練時,我們總共訓練100個epoch,在前5個epoch使用warm up,並且使用cosine learning rate衰減。
RegNet訓練結果
| Network | Top-1 error |
| --- | --- |
RegNetY_200MF | 29.904 |
RegNetY_400MF | 26.210 |
RegNetY_600MF | 25.276 |
RegNetY_800MF | 24.006 |
RegNetY_1.6GF | 22.692 |
RegNetY_3.2GF | 21.092 |
RegNetY_4.0GF | 21.864 |
RegNetY_6.4GF | 21.230 |