DARTS paper: https://arxiv.org/pdf/1806.09055.pdf
DARTS source code: https://github.com/quark0/darts
Search phase
'''
train_search.py
# Data preparation (CIFAR-10).
During search, the CIFAR-10 training set is re-split 1:1 into a training half and a validation half.
'''
train_transform, valid_transform = utils._data_transforms_cifar10(args)
train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
# In the paper, args.train_portion is 0.5
num_train = len(train_data)
indices = list(range(num_train))
split = int(np.floor(args.train_portion * num_train))
train_queue = torch.utils.data.DataLoader(
train_data, batch_size=args.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
pin_memory=True, num_workers=2)
valid_queue = torch.utils.data.DataLoader(
train_data, batch_size=args.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
pin_memory=True, num_workers=2)
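A quick standalone check of the 1:1 split (a minimal sketch; the values mirror the code above):
import numpy as np

num_train, train_portion = 50000, 0.5  # CIFAR-10 training set, paper's split
indices = list(range(num_train))
split = int(np.floor(train_portion * num_train))
print(len(indices[:split]), len(indices[split:num_train]))  # 25000 25000 -- disjoint halves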
'''
train_search.py
Search network.
Loss: cross-entropy.
Optimizer: SGD with momentum.
LR schedule: cosine annealing (CosineAnnealingLR).
'''
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
model.parameters(),
args.learning_rate,
momentum=args.momentum,
weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs), eta_min=args.learning_rate_min)
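CosineAnnealingLR follows $\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max}-\eta_{min})(1+\cos(\frac{t}{T}\pi))$. A quick sketch of the schedule, assuming the repo's default values (learning_rate=0.025, learning_rate_min=0.001, epochs=50):
import math

eta_max, eta_min, T = 0.025, 0.001, 50  # assumed defaults from train_search.py
lrs = [eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(math.pi * t / T)) for t in range(T + 1)]
print(round(lrs[0], 4), round(lrs[T // 2], 4), round(lrs[T], 4))  # 0.025 0.013 0.001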
'''
train_search.py
Build the search network.
Build the Architect, which optimizes the architecture parameters.
'''
# in model_search.py
model = Network(args.init_channels, CIFAR_CLASSES, args.layers, criterion)
# in architect.py
architect = Architect(model, args)
'''
model_search.py
In the paper:
# C: 16
# num_classes: 10 (CIFAR_CLASSES)
# criterion: cross-entropy
# layers: 8
'''
class Network(nn.Module):
def __init__(self, C, num_classes, layers, criterion, steps=4, multiplier=4, stem_multiplier=3):
super(Network, self).__init__()
self._C = C
self._num_classes = num_classes
self._layers = layers
self._criterion = criterion
self._steps = steps
self._multiplier = multiplier
C_curr = stem_multiplier*C
# stem: initial conv + bn
self.stem = nn.Sequential(
nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
nn.BatchNorm2d(C_curr)
)
C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
self.cells = nn.ModuleList()
reduction_prev = False
# Build `layers` (= 8) cells,
# split into normal cells and reduction cells (reduction cells double the channels)
for i in range(layers):
if i in [layers//3, 2*layers//3]:
# Of the 8 cells, those at indices layers//3 and 2*layers//3 (i.e. cells 2 and 5) are reduction cells; channels double there
C_curr *= 2
reduction = True
else:
reduction = False
cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
reduction_prev = reduction
self.cells += [cell]
C_prev_prev, C_prev = C_prev, multiplier*C_curr
# After stacking the cells, attach the classification head
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
# Initialize the architecture parameters alpha
self._initialize_alphas()
# Build a new Network and copy the alpha parameters over
def new(self):
model_new = Network(self._C, self._num_classes, self._layers, self._criterion).cuda()
for x, y in zip(model_new.arch_parameters(), self.arch_parameters()):
x.data.copy_(y.data)
return model_new
def forward(self, input):
s0 = s1 = self.stem(input)
for i, cell in enumerate(self.cells):
# Reduction cells and normal cells have separate shared alpha parameters
if cell.reduction:
# Softmax normalization of the (14, 8) alphas: one softmax over the 8 candidate ops of each edge
weights = F.softmax(self.alphas_reduce, dim=-1)
else:
weights = F.softmax(self.alphas_normal, dim=-1)
# Wiring between cells: s0 is the output of the cell before last, s1 of the previous cell
s0, s1 = s1, cell(s0, s1, weights)
out = self.global_pooling(s1)
logits = self.classifier(out.view(out.size(0),-1))
return logits
def _loss(self, input, target):
logits = self(input)
return self._criterion(logits, target)
# Initialize alpha
def _initialize_alphas(self):
# 14 edges in total: the 4 intermediate nodes have 2+3+4+5 incoming edges
k = sum(1 for i in range(self._steps) for n in range(2+i))
num_ops = len(PRIMITIVES)
# shape (14, 8): one row per edge, one column per op;
# one alpha matrix for normal cells
# and one for reduction cells
self.alphas_normal = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
self.alphas_reduce = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
self._arch_parameters = [
self.alphas_normal,
self.alphas_reduce,
]
def arch_parameters(self):
return self._arch_parameters
def genotype(self):
def _parse(weights):
gene = []
n = 2
start = 0
for i in range(self._steps):
# For each intermediate node
end = start + n
# All weights of the edges into this node (2, 3, 4, or 5 rows)
W = weights[start:end].copy()
# For each node, rank its candidate predecessors by the largest edge weight (excluding 'none') and keep the best 2 edges, i.e. decide which two nodes it connects to.
# Note that this only fixes the connections; the op on each chosen edge is selected next.
edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[:2]
# For each of the two chosen edges, pick the op with the largest weight.
# This selection rule feels rather coarse: judged by the alpha weights, the 2nd-best op on edge 1 may well beat the best op on edge 2. It also forbids keeping several ops on the same edge, which arguably can be reasonable too.
# Later papers improve on this selection strategy, e.g. FairDARTS (covered in a later blog post).
for j in edges:
k_best = None
for k in range(len(W[j])):
if k != PRIMITIVES.index('none'):
if k_best is None or W[j][k] > W[j][k_best]:
k_best = k
# Record the best op and the corresponding edge (which predecessor node it connects to)
# Each node keeps exactly two edges with one op each, so a cell has 2*4 = 8 operations -- hard-coded and not very flexible!
gene.append((PRIMITIVES[k_best], j))
start = end
n += 1
return gene
# Normalize, then pick the best op on each edge according to the rule above
gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).data.cpu().numpy())
gene_reduce = _parse(F.softmax(self.alphas_reduce, dim=-1).data.cpu().numpy())
# range(2, 6): the four intermediate nodes
concat = range(2+self._steps-self._multiplier, self._steps+2)
genotype = Genotype(
normal=gene_normal, normal_concat=concat,
reduce=gene_reduce, reduce_concat=concat
)
return genotype
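To watch this discretization in isolation, here is a standalone toy version of the same _parse logic run on random weights (a sketch; PRIMITIVES mirrors the op list in genotypes.py):
import numpy as np

PRIMITIVES = ['none', 'max_pool_3x3', 'avg_pool_3x3', 'skip_connect',
              'sep_conv_3x3', 'sep_conv_5x5', 'dil_conv_3x3', 'dil_conv_5x5']

def parse(weights, steps=4):
    gene, start = [], 0
    none_idx = PRIMITIVES.index('none')
    for i in range(steps):
        n = i + 2                       # number of incoming edges for node i
        W = weights[start:start + n]
        # keep the two incoming edges with the largest non-'none' weight
        edges = sorted(range(n), key=lambda x: -max(
            W[x][k] for k in range(len(W[x])) if k != none_idx))[:2]
        for j in edges:
            # pick the strongest op on each kept edge
            k_best = max((k for k in range(len(W[j])) if k != none_idx),
                         key=lambda k: W[j][k])
            gene.append((PRIMITIVES[k_best], j))
        start += n
    return gene

rng = np.random.RandomState(0)
print(parse(rng.rand(14, 8)))  # 8 (op, predecessor) pairs, two per intermediate node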
'''
model_search.py
Implementation of a cell with shared parameters; cells are either normal cells or reduction cells.
A reduction cell halves the feature map.
'''
# One MixedOp per edge
class MixedOp(nn.Module):
def __init__(self, C, stride):
super(MixedOp, self).__init__()
self._ops = nn.ModuleList()
# the 8 candidate ops
for primitive in PRIMITIVES:
# instantiate each candidate op
op = OPS[primitive](C, stride, False)
# pooling ops are followed by a BatchNorm
if 'pool' in primitive:
op = nn.Sequential(op, nn.BatchNorm2d(C, affine=False))
self._ops.append(op)
def forward(self, x, weights):
# Sum over all ops on this edge; weights holds one weight per op on the edge.
return sum(w * op(x) for w, op in zip(weights, self._ops))
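The continuous relaxation is just a weighted average of candidate-op outputs that all preserve the tensor shape. A toy demonstration with three placeholder ops (sketch only, not the repo's op set):
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 16, 32, 32)
ops = nn.ModuleList([nn.Identity(),
                     nn.MaxPool2d(3, stride=1, padding=1),
                     nn.AvgPool2d(3, stride=1, padding=1)])
weights = F.softmax(torch.randn(len(ops)), dim=-1)       # one weight per op
mixed = sum(w * op(x) for w, op in zip(weights, ops))
print(mixed.shape)  # torch.Size([1, 16, 32, 32]) -- every op must keep the same shape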
'''
steps: number of intermediate nodes (4)
multiplier: how many node outputs are concatenated into the cell output (4)
C_prev_prev: channels of the cell before last
C_prev: channels of the previous cell
reduction: whether this cell is a reduction cell
reduction_prev: whether the previous cell is a reduction cell
'''
class Cell(nn.Module):
def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev):
super(Cell, self).__init__()
self.reduction = reduction
# Input s0 comes from the output of the cell before last; channels C_prev_prev -> C
if reduction_prev:
# If the previous cell was a reduction cell:
# FactorizedReduce maps channels C_prev_prev -> C and halves the feature map.
# It concatenates two stride-2 convs with C // 2 output channels each.
# The halved feature map then matches the previous reduction cell's (already halved) output.
self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False)
else:
# Feature map size unchanged; channels C_prev_prev -> C
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
# Input s1 comes from the previous cell's output; ReLUConvBN = relu + conv + bn
# Feature map size unchanged; channels C_prev -> C
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
self._steps = steps
self._multiplier = multiplier
self._ops = nn.ModuleList()
self._bns = nn.ModuleList()
# For the 4 intermediate nodes
for i in range(self._steps):
for j in range(2+i):
# In a reduction cell, every edge that starts from the inputs s0 or s1 (j < 2) uses stride=2, halving the feature map
stride = 2 if reduction and j < 2 else 1
# One MixedOp (8 candidate ops) for each of the 14 edges
op = MixedOp(C, stride)
self._ops.append(op)
def forward(self, s0, s1, weights):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
# Wiring between cells: s0 is the output of the cell before last, s1 of the previous cell
states = [s0, s1]
offset = 0
for i in range(self._steps):
# For each node, sum the feature maps arriving over all of its incoming edges:
# [s0,s1]           ops[0](s0,weights[0]) + ops[1](s1,weights[1])
# [s0,s1,sa]        ops[2](s0,weights[2]) + ops[3](s1,weights[3]) + ops[4](sa,weights[4])
# [s0,s1,sa,sb]     ops 5,6,7,8
# [s0,s1,sa,sb,sc]  ops 9,10,11,12,13
s = sum(self._ops[offset+j](h, weights[offset+j]) for j, h in enumerate(states))
offset += len(states)
states.append(s)
# [s0,s1,sa,sb,sc,sd]
# Concatenate the last four node outputs (sa,sb,sc,sd) as the cell output: 4*C channels
return torch.cat(states[-self._multiplier:], dim=1)
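The offset bookkeeping can be traced standalone (a minimal sketch): node i consumes len(states) = i + 2 edges, giving per-node offsets 0, 2, 5, 9 and 14 edges in total.
offset, num_inputs, layout = 0, 2, []
for i in range(4):                       # 4 intermediate nodes
    layout.append((offset, num_inputs))  # (first edge index, number of incoming edges)
    offset += num_inputs
    num_inputs += 1
print(layout, offset)                    # [(0, 2), (2, 3), (5, 4), (9, 5)] 14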
Approximate architecture gradient
Step 1: update $\alpha$; step 2: update $w$. The goal of architecture search is to find the $\alpha$ that minimizes the validation loss $L_{val}(w^*, \alpha^*)$, where $w^*$ is obtained by minimizing the training loss $L_{train}(w, \alpha^*)$. This is a bilevel optimization problem.
(1) The optimization objective:
$$\min_{\alpha} \; L_{val}(w^*(\alpha), \alpha) \quad \text{s.t.} \quad w^*(\alpha) = \operatorname{argmin}_{w} L_{train}(w, \alpha)$$
(2) Approximate the inner optimization with a single step:
$$\nabla_{\alpha} L_{val}(w^*(\alpha), \alpha) \approx \nabla_{\alpha} L_{val}\big(w - \xi \nabla_{w} L_{train}(w, \alpha), \alpha\big)$$
(3) Apply the chain rule (momentum is folded into the inner SGD step in the code):
$$\nabla_{\alpha} L_{val}(w', \alpha) - \xi \nabla^2_{\alpha, w} L_{train}(w, \alpha)\, \nabla_{w'} L_{val}(w', \alpha), \quad \text{where } w' = w - \xi \nabla_{w} L_{train}(w, \alpha)$$
The difference between Eq. (2) and Eq. (3): (2) differentiates the composite function, while (3) expands it into the two partial-derivative terms.
(4) A finite-difference approximation of the second-order term:
$$\nabla^2_{\alpha, w} L_{train}(w, \alpha)\, \nabla_{w'} L_{val}(w', \alpha) \approx \frac{\nabla_{\alpha} L_{train}(w^+, \alpha) - \nabla_{\alpha} L_{train}(w^-, \alpha)}{2\epsilon}, \quad w^{\pm} = w \pm \epsilon \nabla_{w'} L_{val}(w', \alpha)$$
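Equation (4) can be sanity-checked numerically on a toy loss $L(w, \alpha) = (w\alpha)^2$, whose mixed second derivative is $4w\alpha$ in closed form; a minimal sketch:
import torch

w = torch.tensor([1.5], requires_grad=True)
a = torch.tensor([0.7], requires_grad=True)
v = torch.tensor([0.3])                  # stands in for dw'Lval(w', a)
eps = 0.01 / v.norm()

def loss(w, a):
    return ((w * a) ** 2).sum()

g_p = torch.autograd.grad(loss(w + eps * v, a), a)[0]   # daLtrain(w+, a)
g_n = torch.autograd.grad(loss(w - eps * v, a), a)[0]   # daLtrain(w-, a)
print((g_p - g_n) / (2 * eps))           # finite-difference estimate
print(4 * w.detach() * a.detach() * v)   # exact Hessian-vector product, ~1.26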
'''
architect.py
# Optimize the alpha parameters
'''
def _concat(xs):
# Flatten each tensor to one row and concatenate them into a single vector
return torch.cat([x.view(-1) for x in xs])
#architect.step(input, target, input_search, target_search, lr, optimizer, unrolled=args.unrolled)
class Architect(object):
def __init__(self, model, args):
self.network_momentum = args.momentum
self.network_weight_decay = args.weight_decay
self.model = model
# Optimizes only arch_parameters
self.optimizer = torch.optim.Adam(self.model.arch_parameters(),
lr=args.arch_learning_rate, betas=(0.5, 0.999), weight_decay=args.arch_weight_decay)
def _compute_unrolled_model(self, input, target, eta, network_optimizer):
# One inner SGD step on the omega parameters w.r.t. the training loss Ltrain:
# w' = w - xi * (momentum * buffer + dwLtrain(w, a) + weight_decay * w)
loss = self.model._loss(input, target)
# Flatten all parameters into one vector theta (the parameters to be updated)
theta = _concat(self.model.parameters()).data
try:
# include the momentum buffer
moment = _concat(network_optimizer.state[v]['momentum_buffer'] for v in self.model.parameters()).mul_(self.network_momentum)
except:
# no momentum buffer yet (e.g. before the first optimizer step)
moment = torch.zeros_like(theta)
dtheta = _concat(torch.autograd.grad(loss, self.model.parameters())).data + self.network_weight_decay*theta
# w − ξ*dwLtrain(w, α)
unrolled_model = self._construct_model_from_theta(theta.sub(eta, moment+dtheta))
return unrolled_model
def step(self, input_train, target_train, input_valid, target_valid, eta, network_optimizer, unrolled):
# Clear gradients left over from the previous alpha update
self.optimizer.zero_grad()
if unrolled:
# Use the second-order (unrolled) method proposed in the paper
self._backward_step_unrolled(input_train, target_train, input_valid, target_valid, eta, network_optimizer)
else:
# First-order approximation: alternating optimization that differentiates Lval w.r.t. alpha directly
self._backward_step(input_valid, target_valid)
self.optimizer.step() # the optimizer holds references to the alpha parameters
def _backward_step(self, input_valid, target_valid):
# Backpropagate to compute the alpha gradients
loss = self.model._loss(input_valid, target_valid)
loss.backward()
def _backward_step_unrolled(self, input_train, target_train, input_valid, target_valid, eta, network_optimizer):
# Compute w' = w - xi*dwLtrain(w, a)
# unrolled_model holds the updated weights w'
unrolled_model = self._compute_unrolled_model(input_train, target_train, eta, network_optimizer)
# Compute daLval(w', a):
# evaluate the validation loss on the unrolled model (whose w has taken one step) in order to update alpha
unrolled_loss = unrolled_model._loss(input_valid, target_valid)
unrolled_loss.backward()
# dα Lval(w',α)
dalpha = [v.grad for v in unrolled_model.arch_parameters()]
# dw'Lval(w',α)
vector = [v.grad.data for v in unrolled_model.parameters()]
# Compute (daLtrain(w+, a) - daLtrain(w-, a)) / (2*epsilon)
implicit_grads = self._hessian_vector_product(vector, input_train, target_train)
# daLval(w', a) - eta * (daLtrain(w+, a) - daLtrain(w-, a)) / (2*epsilon)
for g, ig in zip(dalpha, implicit_grads):
g.data.sub_(eta, ig.data)
# Write the result into the alpha gradients; optimizer.step() then updates alpha
for v, g in zip(self.model.arch_parameters(), dalpha):
if v.grad is None:
v.grad = Variable(g.data)
else:
v.grad.data.copy_(g.data)
def _construct_model_from_theta(self, theta):
# Build a new network and copy the alpha parameters
model_new = self.model.new()
model_dict = self.model.state_dict()
# Copy theta back into tensors with the original parameter shapes
params, offset = {}, 0
for k, v in self.model.named_parameters():
v_length = np.prod(v.size())
params[k] = theta[offset: offset+v_length].view(v.size())
offset += v_length
assert offset == len(theta)
model_dict.update(params)
model_new.load_state_dict(model_dict)
# Return a model whose weights are the values after one unrolled update step
return model_new.cuda()
# Compute (daLtrain(w+, a) - daLtrain(w-, a)) / (2*epsilon)
# w+ = w + epsilon * dw'Lval(w', a)
# w- = w - epsilon * dw'Lval(w', a)
def _hessian_vector_product(self, vector, input, target, r=1e-2):
R = r / _concat(vector).norm()
# w+ = w + epsilon * dw'Lval(w', a)
for p, v in zip(self.model.parameters(), vector):
p.data.add_(R, v)
# dαLtrain(w+,α)
loss = self.model._loss(input, target)
grads_p = torch.autograd.grad(loss, self.model.arch_parameters())
# w- = w - epsilon * dw'Lval(w', a) (subtract 2*R from w+)
for p, v in zip(self.model.parameters(), vector):
p.data.sub_(2*R, v)
# dαLtrain(w-,α)
loss = self.model._loss(input, target)
grads_n = torch.autograd.grad(loss, self.model.arch_parameters())
# restore the original weights w
for p, v in zip(self.model.parameters(), vector):
p.data.add_(R, v)
return [(x-y).div_(2*R) for x, y in zip(grads_p, grads_n)]
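For intuition, the update computed by _compute_unrolled_model can be reproduced for a single scalar weight (a sketch; the loss and hyper-parameter values below are made up):
import torch

mu, wd, eta, buf = 0.9, 3e-4, 0.025, 0.1  # momentum, weight decay, lr, momentum buffer (assumed)
w = torch.tensor(0.8, requires_grad=True)
loss = (w - 1.0) ** 2                      # stand-in for Ltrain
grad, = torch.autograd.grad(loss, w)
# w' = w - eta * (mu * buffer + dwLtrain + wd * w), matching theta.sub(eta, moment + dtheta)
w_unrolled = w.detach() - eta * (mu * buf + grad + wd * w.detach())
print(w_unrolled)  # tensor(0.8077...)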
'''
train_search.py
training && validation
'''
# training
train_acc, train_obj = train(train_queue, valid_queue, model, architect, criterion, optimizer, lr)
# validation
valid_acc, valid_obj = infer(valid_queue, model, criterion)
def train(train_queue, valid_queue, model, architect, criterion, optimizer, lr):
#...
#...
for step, (input, target) in enumerate(train_queue):
model.train()
input = Variable(input, requires_grad=False).cuda()
target = Variable(target, requires_grad=False).cuda()
# get a random minibatch from the search queue with replacement
input_search, target_search = next(iter(valid_queue))
input_search = Variable(input_search, requires_grad=False).cuda()
target_search = Variable(target_search, requires_grad=False).cuda()
# Step 1: optimize alpha, the architecture parameters.
# DARTS alternates between the two optimization steps.
architect.step(input, target, input_search, target_search, lr, optimizer, unrolled=args.unrolled)
# Step 2: optimize omega, the network (convolution) weights
optimizer.zero_grad()
logits = model(input)
loss = criterion(logits, target)
loss.backward()
nn.utils.clip_grad_norm(model.parameters(), args.grad_clip)
optimizer.step()
prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
n = input.size(0)
objs.update(loss.data[0], n)
top1.update(prec1.data[0], n)
top5.update(prec5.data[0], n)
#...
#...
return top1.avg, objs.avg
def infer(valid_queue, model, criterion):
#...
#...
model.eval()
for step, (input, target) in enumerate(valid_queue):
input = Variable(input, volatile=True).cuda()
target = Variable(target, volatile=True).cuda()
logits = model(input)
loss = criterion(logits, target)
prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
n = input.size(0)
objs.update(loss.item(), n)
top1.update(prec1.item(), n)
top5.update(prec5.item(), n)
#...
#...
return top1.avg, objs.avg
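utils.accuracy returns top-k precision percentages; a minimal equivalent sketch (the repo's actual implementation may differ in details):
import torch

def accuracy(logits, target, topk=(1,)):
    _, pred = logits.topk(max(topk), dim=1)   # (batch, maxk) predicted labels
    correct = pred.eq(target.view(-1, 1))     # which predictions hit the target
    return [correct[:, :k].float().sum() * (100.0 / target.size(0)) for k in topk]

logits, target = torch.randn(8, 10), torch.randint(0, 10, (8,))
print(accuracy(logits, target, topk=(1, 5)))  # [top-1 %, top-5 %]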
Train phase
'''
train.py
# Data preparation (CIFAR-10).
The CIFAR-10 training and test sets are used directly (no re-split).
# Network hyper-parameters
               search network   train network
init_channels:       16               36
layers:               8               20
The train network additionally has an auxiliary head.
'''
'''
train.py
Loading the derived child network
'''
from model import NetworkCIFAR as Network
genotype = eval("genotypes.%s" % args.arch)
model = Network(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype)
'''
model.py
Building the child network structure.
Compared with the search network there is an extra auxiliary head, and each cell is built directly from the genotype.
'''
class Cell(nn.Module):
def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
super(Cell, self).__init__()
print(C_prev_prev, C_prev, C)
if reduction_prev:
self.preprocess0 = FactorizedReduce(C_prev_prev, C)
else:
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
# Build the ops from the genotype
if reduction:
op_names, indices = zip(*genotype.reduce)
concat = genotype.reduce_concat
else:
op_names, indices = zip(*genotype.normal)
concat = genotype.normal_concat
self._compile(C, op_names, indices, concat, reduction)
def _compile(self, C, op_names, indices, concat, reduction):
assert len(op_names) == len(indices)
self._steps = len(op_names) // 2
self._concat = concat
self.multiplier = len(concat)
self._ops = nn.ModuleList()
for name, index in zip(op_names, indices):
stride = 2 if reduction and index < 2 else 1
op = OPS[name](C, stride, True)
self._ops += [op]
self._indices = indices
def forward(self, s0, s1, drop_prob):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
for i in range(self._steps):
h1 = states[self._indices[2*i]]
h2 = states[self._indices[2*i+1]]
op1 = self._ops[2*i]
op2 = self._ops[2*i+1]
h1 = op1(h1)
h2 = op2(h2)
if self.training and drop_prob > 0.:
if not isinstance(op1, Identity):
h1 = drop_path(h1, drop_prob)
if not isinstance(op2, Identity):
h2 = drop_path(h2, drop_prob)
s = h1 + h2
states += [s]
return torch.cat([states[i] for i in self._concat], dim=1)
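drop_path (defined in utils.py in the repo) zeroes an entire branch output per sample with probability drop_prob and rescales the survivors so the expectation is unchanged. A modernized sketch of the idea:
import torch

def drop_path(x, drop_prob):
    # zero each sample's branch with prob drop_prob; rescale to keep the expectation
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        mask = torch.bernoulli(torch.full((x.size(0), 1, 1, 1), keep_prob, device=x.device))
        x = x.div(keep_prob).mul(mask)
    return x

print(drop_path(torch.ones(4, 3, 1, 1), 0.5).view(4, -1))  # rows are all-0 or all-2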
class AuxiliaryHeadCIFAR(nn.Module):
def __init__(self, C, num_classes):
"""assuming input size 8x8"""
super(AuxiliaryHeadCIFAR, self).__init__()
self.features = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2
nn.Conv2d(C, 128, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, 2, bias=False),
nn.BatchNorm2d(768),
nn.ReLU(inplace=True)
)
self.classifier = nn.Linear(768, num_classes)
def forward(self, x):
x = self.features(x)
x = self.classifier(x.view(x.size(0),-1))
return x
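Tracing the shapes for the 8x8 input assumed in the docstring: 8x8 -> AvgPool(5, stride 3) -> 2x2 -> 1x1 conv -> 2x2 -> 2x2 conv -> 1x1, so the classifier sees a 768-dim vector. A quick check with the class above in scope (C=576 is what the 36-channel, 20-layer train network would feed in, by my reckoning):
import torch

head = AuxiliaryHeadCIFAR(C=576, num_classes=10)
x = torch.randn(2, 576, 8, 8)
print(head(x).shape)  # torch.Size([2, 10])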
'''
Network structure
model = Network(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype)
'''
class NetworkCIFAR(nn.Module):
def __init__(self, C, num_classes, layers, auxiliary, genotype):
super(NetworkCIFAR, self).__init__()
self._layers = layers
self._auxiliary = auxiliary
stem_multiplier = 3
C_curr = stem_multiplier*C
# stem
self.stem = nn.Sequential(
nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
nn.BatchNorm2d(C_curr)
)
C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
self.cells = nn.ModuleList()
reduction_prev = False
for i in range(layers):
if i in [layers//3, 2*layers//3]:
C_curr *= 2
reduction = True
else:
reduction = False
cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
reduction_prev = reduction
self.cells += [cell]
C_prev_prev, C_prev = C_prev, cell.multiplier*C_curr
# Unlike the search network, there is an auxiliary head
if i == 2*layers//3:
C_to_auxiliary = C_prev
if auxiliary:
self.auxiliary_head = AuxiliaryHeadCIFAR(C_to_auxiliary, num_classes)
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
def forward(self, input):
logits_aux = None
s0 = s1 = self.stem(input)
for i, cell in enumerate(self.cells):
s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
if i == 2*self._layers//3:
if self._auxiliary and self.training:
logits_aux = self.auxiliary_head(s1)
out = self.global_pooling(s1)
logits = self.classifier(out.view(out.size(0),-1))
return logits, logits_aux
'''
train.py
Training the child network
'''
for epoch in range(args.epochs):
scheduler.step()
# ...
model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
train_acc, train_obj = train(train_queue, model, criterion, optimizer)
valid_acc, valid_obj = infer(valid_queue, model, criterion)
# ...
def train(train_queue, model, criterion, optimizer):
objs = utils.AvgrageMeter()
top1 = utils.AvgrageMeter()
top5 = utils.AvgrageMeter()
model.train()
for step, (input, target) in enumerate(train_queue):
input = Variable(input).cuda()
target = Variable(target).cuda()
optimizer.zero_grad()
# the model returns (logits, logits_aux)
logits, logits_aux = model(input)
loss = criterion(logits, target)
# Add the auxiliary loss scaled by its weight
if args.auxiliary:
loss_aux = criterion(logits_aux, target)
loss += args.auxiliary_weight*loss_aux
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
optimizer.step()
prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
n = input.size(0)
objs.update(loss.item(), n)
top1.update(prec1.item(), n)
top5.update(prec5.item(), n)
return top1.avg, objs.avg
def infer(valid_queue, model, criterion):
objs = utils.AvgrageMeter()
top1 = utils.AvgrageMeter()
top5 = utils.AvgrageMeter()
model.eval()
for step, (input, target) in enumerate(valid_queue):
input = Variable(input, volatile=True).cuda()
target = Variable(target, volatile=True).cuda()
logits, _ = model(input)
loss = criterion(logits, target)
prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
n = input.size(0)
objs.update(loss.item(), n)
top1.update(prec1.item(), n)
top5.update(prec5.item(), n)
if step % args.report_freq == 0:
logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
return top1.avg, objs.avg