[NAS]Darts代碼解析

darts論文鏈接:https://arxiv.org/pdf/1806.09055.pdf
darts源碼鏈接:https://github.com/quark0/darts


search部分
'''
train_search.py
#數據準備(cifar10)。
搜索時,從cifar10的訓練集中按照1:1重新劃分訓練集和驗證集
'''
# Search phase data: only the CIFAR-10 *training* set is used; it is split
# into a weight-training part and an architecture-validation part.
train_transform, valid_transform = utils._data_transforms_cifar10(args)
train_data = dset.CIFAR10(
    root=args.data,
    train=True,
    download=True,
    transform=train_transform,
)

# The paper uses args.train_portion = 0.5, i.e. a 1:1 split.
num_train = len(train_data)
indices = list(range(num_train))
split = int(np.floor(args.train_portion * num_train))

# First `split` indices feed train_queue, the rest feed valid_queue; each
# loader samples randomly (without repetition) inside its own subset.
train_queue, valid_queue = (
    torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(subset),
        pin_memory=True,
        num_workers=2,
    )
    for subset in (indices[:split], indices[split:num_train])
)
      
'''
train_search.py
搜索網絡
損失函數:交叉熵
優化器:帶動量的SGD
學習率調整策略:餘弦退火調整學習率 CosineAnnealingLR
'''
# Loss for the search network: cross entropy. Created first because
# Network stores it and uses it in Network._loss.
criterion = nn.CrossEntropyLoss()

# Build the search network (defined in model_search.py).
model = Network(args.init_channels, CIFAR_CLASSES, args.layers, criterion)
# Network creates its alphas with .cuda() and train() moves every batch to
# the GPU, so the convolution weights must be moved there as well.
model = model.cuda()

# Weight (omega) optimizer: SGD with momentum. It needs model.parameters(),
# so it can only be created after the model exists.
optimizer = torch.optim.SGD(
    model.parameters(),
    args.learning_rate,
    momentum=args.momentum,
    weight_decay=args.weight_decay)
# Cosine-annealed learning rate from args.learning_rate down to
# args.learning_rate_min over args.epochs epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, float(args.epochs), eta_min=args.learning_rate_min)

# Architecture (alpha) optimizer wrapper (defined in architect.py).
architect = Architect(model, args)

'''
model_search.py
論文中
# C	:16
# num_classes :10 (CIFAR-10)
# criterion
# layers:8
'''
class Network(nn.Module):
    """DARTS search network: stem, a stack of `layers` searchable cells, and a classifier.

    All normal cells share `alphas_normal`; all reduction cells share
    `alphas_reduce`.

    Args:
        C: channel count of the first cell (16 in the paper's search setting).
        num_classes: number of output classes (10 for CIFAR-10).
        layers: number of stacked cells (8 in the paper's search setting).
        criterion: loss module used by `_loss`.
        steps: number of intermediate nodes per cell.
        multiplier: number of last nodes concatenated to form a cell's output.
        stem_multiplier: the stem outputs stem_multiplier * C channels.
    """

    def __init__(self, C, num_classes, layers, criterion, steps=4, multiplier=4, stem_multiplier=3):
        super(Network, self).__init__()
        self._C = C
        self._num_classes = num_classes
        self._layers = layers
        self._criterion = criterion
        self._steps = steps
        self._multiplier = multiplier

        C_curr = stem_multiplier * C
        # Stem: 3x3 conv + BN on the RGB input.
        self.stem = nn.Sequential(
            nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
            nn.BatchNorm2d(C_curr)
        )

        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
        self.cells = nn.ModuleList()
        reduction_prev = False
        for i in range(layers):
            # Cells at index layers//3 and 2*layers//3 (indices 2 and 5, counting
            # from 0, when layers == 8) are reduction cells and double the channels.
            if i in [layers // 3, 2 * layers // 3]:
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
            reduction_prev = reduction
            self.cells += [cell]
            # A cell outputs the concatenation of `multiplier` nodes of C_curr channels.
            C_prev_prev, C_prev = C_prev, multiplier * C_curr

        # After the cells: global average pooling + linear classifier.
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)

        # Create the architecture parameters (alphas).
        self._initialize_alphas()

    def new(self):
        """Return a fresh Network (on the GPU) with the same config and a copy of the current alphas."""
        model_new = Network(self._C, self._num_classes, self._layers, self._criterion).cuda()
        for x, y in zip(model_new.arch_parameters(), self.arch_parameters()):
            x.data.copy_(y.data)
        return model_new

    def forward(self, input):
        s0 = s1 = self.stem(input)
        for i, cell in enumerate(self.cells):
            # Reduction and normal cells use different alphas. Softmax over the
            # last dim turns each edge's row of alphas into op mixing weights.
            if cell.reduction:
                weights = F.softmax(self.alphas_reduce, dim=-1)
            else:
                weights = F.softmax(self.alphas_normal, dim=-1)
            # s0 is the output of the cell two steps back, s1 of the previous cell.
            s0, s1 = s1, cell(s0, s1, weights)
        out = self.global_pooling(s1)
        logits = self.classifier(out.view(out.size(0), -1))
        return logits

    def _loss(self, input, target):
        """Return criterion(self(input), target)."""
        logits = self(input)
        return self._criterion(logits, target)

    def _initialize_alphas(self):
        """Create alphas_normal / alphas_reduce of shape (k, len(PRIMITIVES)).

        k is the number of edges in a cell: node i has i + 2 inputs, so with
        steps=4, k = 2 + 3 + 4 + 5 = 14.
        """
        k = sum(1 for i in range(self._steps) for n in range(2 + i))
        num_ops = len(PRIMITIVES)
        # Small random init so the initial softmax is close to uniform.
        self.alphas_normal = Variable(1e-3 * torch.randn(k, num_ops).cuda(), requires_grad=True)
        self.alphas_reduce = Variable(1e-3 * torch.randn(k, num_ops).cuda(), requires_grad=True)
        self._arch_parameters = [
            self.alphas_normal,
            self.alphas_reduce,
        ]

    def arch_parameters(self):
        """Return the list of architecture parameters [alphas_normal, alphas_reduce]."""
        return self._arch_parameters

    def genotype(self):
        """Derive a discrete Genotype from the current (softmaxed) alphas.

        For every intermediate node:
        - keep the 2 incoming edges whose strongest non-'none' op weight is largest;
        - record the strongest non-'none' op on each kept edge.
        A cell therefore gets 2 * steps ops.
        """

        def _parse(weights):
            gene = []
            n = 2
            start = 0
            for i in range(self._steps):
                # Node i has n = i + 2 incoming edges: rows start..end-1 of weights.
                end = start + n
                W = weights[start:end].copy()
                # Choose which two input nodes to connect; the op is chosen below.
                edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[:2]
                # NOTE: each kept edge gets exactly one op (its own best). Later
                # work (e.g. Fair DARTS) revisits this selection rule.
                for j in edges:
                    k_best = None
                    for k in range(len(W[j])):
                        if k != PRIMITIVES.index('none'):
                            if k_best is None or W[j][k] > W[j][k_best]:
                                k_best = k
                    # (best op name, index of the input node it reads from)
                    gene.append((PRIMITIVES[k_best], j))
                start = end
                n += 1
            return gene

        gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).data.cpu().numpy())
        gene_reduce = _parse(F.softmax(self.alphas_reduce, dim=-1).data.cpu().numpy())
        # Indices of the nodes concatenated into the cell output; range(2, 6) with the defaults.
        concat = range(2 + self._steps - self._multiplier, self._steps + 2)
        genotype = Genotype(
            normal=gene_normal, normal_concat=concat,
            reduce=gene_reduce, reduce_concat=concat
        )
        return genotype
		
'''
model_search.py
cell的實現,參數共享,分爲normal cell 和 reduction cell
經過reduction cell 特徵圖減半
'''
# 對於 每一條連接
class MixedOp(nn.Module):
    """One cell edge: a weighted sum of every candidate op in PRIMITIVES.

    Args:
        C: channel count passed to every candidate op.
        stride: stride passed to every candidate op.
    """

    def __init__(self, C, stride):
        super(MixedOp, self).__init__()
        self._ops = nn.ModuleList()
        # One module per candidate op, in PRIMITIVES order (matches the alpha columns).
        for primitive in PRIMITIVES:
            op = OPS[primitive](C, stride, False)
            # Pooling ops are followed by a non-affine BatchNorm.
            if 'pool' in primitive:
                op = nn.Sequential(op, nn.BatchNorm2d(C, affine=False))
            self._ops.append(op)

    def forward(self, x, weights):
        """Return sum_k weights[k] * op_k(x); weights holds this edge's op weights."""
        return sum(w * op(x) for w, op in zip(weights, self._ops))
    	
'''
steps:
multiplier:
C_prev_prev :上上個cell通道
C_prev:上個cell的通道
reduction:是否是reduction cell
reduction_prev : 上一個cell是否是 reduction cell
'''
class Cell(nn.Module):
    """Searchable DARTS cell: every node is connected to all earlier states by a MixedOp.

    Args:
        steps: number of intermediate nodes.
        multiplier: number of last states concatenated into the output.
        C_prev_prev: channels of the output of the cell two steps back (s0).
        C_prev: channels of the previous cell's output (s1).
        C: working channel count inside the cell.
        reduction: whether this is a reduction cell.
        reduction_prev: whether the previous cell was a reduction cell.
    """

    def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev):
        super(Cell, self).__init__()
        self.reduction = reduction

        # Preprocess s0 to C channels.
        if reduction_prev:
            # The previous cell reduced the feature map, so s0 is larger than s1;
            # FactorizedReduce brings it to the same spatial size with C channels.
            self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False)
        else:
            # Same spatial size: 1x1 ReLU-Conv-BN to C channels.
            self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
        # Preprocess s1 to C channels with a 1x1 ReLU-Conv-BN.
        self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)

        self._steps = steps
        self._multiplier = multiplier
        self._ops = nn.ModuleList()
        # NOTE(review): never populated or read in this class.
        self._bns = nn.ModuleList()
        # Node i has 2 + i incoming edges (from s0, s1 and all earlier nodes).
        for i in range(self._steps):
            for j in range(2 + i):
                # In a reduction cell, edges from the two cell inputs use stride 2
                # so the feature map is halved.
                stride = 2 if reduction and j < 2 else 1
                op = MixedOp(C, stride)
                self._ops.append(op)

    def forward(self, s0, s1, weights):
        """Run the cell; `weights` holds one row of op weights per edge, in _ops order."""
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)

        states = [s0, s1]
        offset = 0
        for i in range(self._steps):
            # Node i = sum of one MixedOp per existing state. With steps=4 the
            # nodes use ops 0-1, 2-4, 5-8 and 9-13 respectively.
            s = sum(self._ops[offset + j](h, weights[offset + j]) for j, h in enumerate(states))
            offset += len(states)
            states.append(s)
        # Output: channel concatenation of the last `multiplier` states
        # (presumably multiplier * C channels, if every op outputs C channels).
        return torch.cat(states[-self._multiplier:], dim=1)
		

approximate architecture gradient

第一步:更新 $\alpha$;第二步:更新 $\omega$(兩者交替進行)。architecture search 的目標是通過最小化驗證集上的 loss $L_{val}(\omega^*,\alpha^*)$ 來得到最優的 $\alpha^*$,其中 $\omega^*$ 是通過最小化訓練集 loss $L_{train}(\omega,\alpha^*)$ 得到的。這是一個 bilevel(雙層)優化問題。

(1)優化目標
$$\min_{\alpha}\ L_{val}(\omega^*(\alpha),\alpha)\qquad \text{s.t.}\quad \omega^*(\alpha)=\arg\min_{\omega} L_{train}(\omega,\alpha)$$
(2)近似優化目標:用一步梯度下降近似 $\omega^*(\alpha)$
$$\nabla_{\alpha}L_{val}(\omega^*(\alpha),\alpha)\approx\nabla_{\alpha}L_{val}\big(\omega-\xi\nabla_{\omega}L_{train}(\omega,\alpha),\ \alpha\big)$$
(3)應用鏈式法則
$$=\nabla_{\alpha}L_{val}(\omega',\alpha)-\xi\,\nabla^2_{\alpha,\omega}L_{train}(\omega,\alpha)\cdot\nabla_{\omega'}L_{val}(\omega',\alpha)$$
其中 $\omega'=\omega-\xi\nabla_{\omega}L_{train}(\omega,\alpha)$。
公式(2)和(3)的區別:(2)中的 $\nabla_{\alpha}$ 是對複合函數求全導數($\omega'$ 本身也依賴 $\alpha$);(3)中第一項的 $\nabla_{\alpha}$ 只是對 $\alpha$ 求偏導,$\omega'$ 對 $\alpha$ 的依賴由第二項體現。
(4)用泰勒展開(有限差分)進一步近似二階項
$$\nabla^2_{\alpha,\omega}L_{train}(\omega,\alpha)\cdot\nabla_{\omega'}L_{val}(\omega',\alpha)\approx\frac{\nabla_{\alpha}L_{train}(\omega^+,\alpha)-\nabla_{\alpha}L_{train}(\omega^-,\alpha)}{2\epsilon}$$
其中 $\omega^{\pm}=\omega\pm\epsilon\nabla_{\omega'}L_{val}(\omega',\alpha)$。代碼中 $\epsilon=0.01/\|\nabla_{\omega'}L_{val}(\omega',\alpha)\|_2$(即 `_hessian_vector_product` 中的 `R = r / norm`,`r=1e-2`)。

'''
architect.py
#優化alpha參數

'''
def _concat(xs):
	# 把x view成一行,然後cat成n行
	return torch.cat([x.view(-1) for x in xs]) 
	
#architect.step(input, target, input_search, target_search, lr, optimizer, unrolled=args.unrolled)
class Architect(object):
    """Updates the architecture parameters (alphas) of a search Network.

    Called once per training step, before the weight update:
    architect.step(input, target, input_search, target_search, lr, optimizer, unrolled=args.unrolled)
    """

    def __init__(self, model, args):
        # Momentum / weight decay of the *weight* optimizer, reused to replay one
        # SGD step on the weights when building the unrolled model.
        self.network_momentum = args.momentum
        self.network_weight_decay = args.weight_decay
        self.model = model
        # Adam over the architecture parameters only.
        self.optimizer = torch.optim.Adam(self.model.arch_parameters(),
                                          lr=args.arch_learning_rate, betas=(0.5, 0.999),
                                          weight_decay=args.arch_weight_decay)

    def _compute_unrolled_model(self, input, target, eta, network_optimizer):
        """Return a model copy with weights w' = w - eta * (momentum * buf + dw L_train(w, a) + wd * w)."""
        loss = self.model._loss(input, target)
        # All weights flattened into one vector.
        theta = _concat(self.model.parameters()).data
        # Use the weight optimizer's momentum buffers when it has one for every
        # parameter (e.g. not before its first step); otherwise use zero momentum.
        buffers = [network_optimizer.state[v].get('momentum_buffer')
                   for v in self.model.parameters()]
        if all(buf is not None for buf in buffers):
            moment = _concat(buffers).mul_(self.network_momentum)
        else:
            moment = torch.zeros_like(theta)
        dtheta = _concat(torch.autograd.grad(loss, self.model.parameters())).data + self.network_weight_decay * theta
        # w' = w - eta * (moment + dtheta)
        unrolled_model = self._construct_model_from_theta(theta - eta * (moment + dtheta))
        return unrolled_model

    def step(self, input_train, target_train, input_valid, target_valid, eta, network_optimizer, unrolled):
        """Compute the alpha gradient and apply one Adam step to the alphas.

        unrolled=True uses the paper's second-order approximation; otherwise the
        first-order gradient dα L_val(w, α) is used.
        """
        self.optimizer.zero_grad()
        if unrolled:
            self._backward_step_unrolled(input_train, target_train, input_valid, target_valid, eta, network_optimizer)
        else:
            self._backward_step(input_valid, target_valid)
        # self.optimizer holds references to the alpha tensors.
        self.optimizer.step()

    def _backward_step(self, input_valid, target_valid):
        """First-order: backpropagate L_val(w, α) into the current model."""
        loss = self.model._loss(input_valid, target_valid)
        loss.backward()

    def _backward_step_unrolled(self, input_train, target_train, input_valid, target_valid, eta, network_optimizer):
        """Second-order: set alpha.grad = dα L_val(w', α) - eta * (Hessian-vector term)."""
        # Model with one-step-updated weights w' = w - eta * dw L_train(w, α).
        unrolled_model = self._compute_unrolled_model(input_train, target_train, eta, network_optimizer)
        # Validation loss at w'.
        unrolled_loss = unrolled_model._loss(input_valid, target_valid)
        unrolled_loss.backward()
        # dα L_val(w', α)
        dalpha = [v.grad for v in unrolled_model.arch_parameters()]
        # dw' L_val(w', α)
        vector = [v.grad.data for v in unrolled_model.parameters()]

        # (dα L_train(w+, α) - dα L_train(w-, α)) / (2 * eps)
        implicit_grads = self._hessian_vector_product(vector, input_train, target_train)

        # dα L_val(w', α) - eta * implicit_grads
        for g, ig in zip(dalpha, implicit_grads):
            g.data.sub_(eta * ig.data)
        # Write the result into the real model's alpha gradients.
        for v, g in zip(self.model.arch_parameters(), dalpha):
            if v.grad is None:
                v.grad = Variable(g.data)
            else:
                v.grad.data.copy_(g.data)

    def _construct_model_from_theta(self, theta):
        """Return a new Network (alphas copied) whose weights come from the flat vector theta.

        theta is sliced in named_parameters() order. Other state_dict entries
        are copied from the current model.
        """
        model_new = self.model.new()
        model_dict = self.model.state_dict()

        params, offset = {}, 0
        for k, v in self.model.named_parameters():
            v_length = np.prod(v.size())
            params[k] = theta[offset: offset + v_length].view(v.size())
            offset += v_length

        assert offset == len(theta)
        model_dict.update(params)
        model_new.load_state_dict(model_dict)
        return model_new.cuda()

    def _hessian_vector_product(self, vector, input, target, r=1e-2):
        """Finite-difference estimate of the mixed second derivative of L_train times vector.

        Returns [(dα L_train(w+, α) - dα L_train(w-, α)) / (2 * R)] per arch
        parameter, where w± = w ± R * vector and R = r / ||vector||.
        The weights are shifted back to w (up to rounding) before returning.
        """
        R = r / _concat(vector).norm()
        # w+ = w + R * vector
        for p, v in zip(self.model.parameters(), vector):
            p.data.add_(R * v)
        loss = self.model._loss(input, target)
        grads_p = torch.autograd.grad(loss, self.model.arch_parameters())

        # w- = w - R * vector (a 2R step back from w+)
        for p, v in zip(self.model.parameters(), vector):
            p.data.sub_(2 * R * v)
        loss = self.model._loss(input, target)
        grads_n = torch.autograd.grad(loss, self.model.arch_parameters())

        # Restore w.
        for p, v in zip(self.model.parameters(), vector):
            p.data.add_(R * v)

        return [(x - y).div_(2 * R) for x, y in zip(grads_p, grads_n)]


'''
train_search.py
training &&  validation
'''
# NOTE(review): as excerpted, these calls run at module level before `train`
# and `infer` are defined below; in the full script they presumably sit inside
# the per-epoch loop - confirm.

# training: one epoch of alternating alpha / weight updates
train_acc, train_obj = train(
    train_queue=train_queue,
    valid_queue=valid_queue,
    model=model,
    architect=architect,
    criterion=criterion,
    optimizer=optimizer,
    lr=lr,
)

# validation on the held-out half of the CIFAR-10 training set
valid_acc, valid_obj = infer(valid_queue=valid_queue, model=model, criterion=criterion)

def train(train_queue, valid_queue, model, architect, criterion, optimizer, lr):
    """Run one search epoch, alternating an alpha update and a weight update per batch.

    Args:
        train_queue: loader over the weight-training split.
        valid_queue: loader over the architecture-validation split.
        model: search Network.
        architect: Architect that updates the alphas.
        criterion: loss for the weight update.
        optimizer: weight optimizer.
        lr: current weight learning rate (the eta of the unrolled step).

    Returns:
        (average top-1 accuracy, average loss) over the epoch.
    """
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()

    for step, (input, target) in enumerate(train_queue):
        model.train()
        n = input.size(0)
        input = input.cuda()
        target = target.cuda()

        # A fresh iterator each step draws a random minibatch from the search
        # split (batches may repeat across steps).
        input_search, target_search = next(iter(valid_queue))
        input_search = input_search.cuda()
        target_search = target_search.cuda()

        # Step 1: update alpha (architecture parameters).
        architect.step(input, target, input_search, target_search, lr, optimizer, unrolled=args.unrolled)

        # Step 2: update omega (network weights). zero_grad also clears any
        # weight gradients left by the first-order architect step.
        optimizer.zero_grad()
        logits = model(input)
        loss = criterion(logits, target)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
        objs.update(loss.item(), n)
        top1.update(prec1.item(), n)
        top5.update(prec5.item(), n)

        if step % args.report_freq == 0:
            logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

    return top1.avg, objs.avg

def infer(valid_queue, model, criterion):
    """Evaluate the search network on valid_queue without tracking gradients.

    Returns:
        (average top-1 accuracy, average loss).
    """
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()
    model.eval()

    # Variable(volatile=True) has no effect on modern PyTorch; no_grad actually
    # stops autograd from building a graph during evaluation.
    with torch.no_grad():
        for step, (input, target) in enumerate(valid_queue):
            input = input.cuda()
            target = target.cuda()
            logits = model(input)
            loss = criterion(logits, target)

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
            n = input.size(0)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            if step % args.report_freq == 0:
                logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

    return top1.avg, objs.avg
	

train部分

'''
train.py
#數據準備(cifar10)。
	從cifar10的訓練集和驗證集中直接取
#網絡參數
					搜索網絡   訓練網絡
	init_channels:	  16		36
	layers:		   8		20
	訓練網絡多一個auxiliary
'''

'''
train.py
子網加載
'''
from model import NetworkCIFAR as Network
# Look the genotype up by name on the genotypes module. getattr avoids passing
# the --arch command-line value through eval(), which would execute arbitrary code.
genotype = getattr(genotypes, args.arch)
model = Network(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype)
# train() moves every batch to the GPU with .cuda(), so the model must be there too.
model = model.cuda()

'''
model.py
子網結構加載
比搜索網絡多一個auxiliary head;cell中不再有alpha,而是直接按genotype記錄的op和連接構建固定結構

'''
class Cell(nn.Module):
  """Fixed cell of the evaluation network, built from a derived genotype.

  Args:
    genotype: Genotype whose `normal` / `reduce` lists hold (op_name, input_index)
      pairs, two per intermediate node, plus the `*_concat` output node indices.
    C_prev_prev: channels of the output of the cell two steps back (s0).
    C_prev: channels of the previous cell's output (s1).
    C: working channel count inside the cell.
    reduction: whether this is a reduction cell.
    reduction_prev: whether the previous cell was a reduction cell.
  """

  def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
    super(Cell, self).__init__()

    # Bring both cell inputs to C channels; FactorizedReduce is used for s0
    # when the previous cell was a reduction cell.
    if reduction_prev:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)

    # Pick the reduce or normal half of the genotype.
    if reduction:
      op_names, indices = zip(*genotype.reduce)
      concat = genotype.reduce_concat
    else:
      op_names, indices = zip(*genotype.normal)
      concat = genotype.normal_concat
    self._compile(C, op_names, indices, concat, reduction)

  def _compile(self, C, op_names, indices, concat, reduction):
    """Instantiate one op per genotype entry; each node has two entries."""
    assert len(op_names) == len(indices)
    self._steps = len(op_names) // 2
    self._concat = concat
    self.multiplier = len(concat)

    self._ops = nn.ModuleList()
    for name, index in zip(op_names, indices):
      # In a reduction cell, ops that read one of the two cell inputs use stride 2.
      stride = 2 if reduction and index < 2 else 1
      op = OPS[name](C, stride, True)
      self._ops += [op]
    self._indices = indices

  def forward(self, s0, s1, drop_prob):
    """Run the cell; drop_prob is the drop-path probability used while training."""
    s0 = self.preprocess0(s0)
    s1 = self.preprocess1(s1)

    states = [s0, s1]
    for i in range(self._steps):
      # Node i = op1(state a) + op2(state b), as recorded in the genotype.
      h1 = states[self._indices[2*i]]
      h2 = states[self._indices[2*i+1]]
      op1 = self._ops[2*i]
      op2 = self._ops[2*i+1]
      h1 = op1(h1)
      h2 = op2(h2)
      # Drop-path is applied only in training and never to Identity ops.
      if self.training and drop_prob > 0.:
        if not isinstance(op1, Identity):
          h1 = drop_path(h1, drop_prob)
        if not isinstance(op2, Identity):
          h2 = drop_path(h2, drop_prob)
      s = h1 + h2
      states += [s]
    return torch.cat([states[i] for i in self._concat], dim=1)


class AuxiliaryHeadCIFAR(nn.Module):
  """Auxiliary classifier applied to an intermediate feature map.

  Expects an 8x8 input: the average pool reduces it to 2x2 and the 2x2 conv
  to 1x1, leaving a 768-dim vector for the linear classifier.
  """

  def __init__(self, C, num_classes):
    super(AuxiliaryHeadCIFAR, self).__init__()
    # Layer order fixes the state_dict keys (features.<idx>.*); keep it stable.
    layers = [
      nn.ReLU(inplace=True),
      nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False),  # 8x8 -> 2x2
      nn.Conv2d(C, 128, 1, bias=False),
      nn.BatchNorm2d(128),
      nn.ReLU(inplace=True),
      nn.Conv2d(128, 768, 2, bias=False),  # 2x2 -> 1x1
      nn.BatchNorm2d(768),
      nn.ReLU(inplace=True),
    ]
    self.features = nn.Sequential(*layers)
    self.classifier = nn.Linear(768, num_classes)

  def forward(self, x):
    feats = self.features(x)
    flat = feats.view(feats.size(0), -1)
    return self.classifier(flat)


'''
網絡結構
model = Network(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype)
'''
class NetworkCIFAR(nn.Module):
  """Evaluation network for CIFAR: stem, `layers` fixed-genotype cells, classifier.

  Usage: Network(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype)

  Args:
    C: channel count of the first cell.
    num_classes: number of output classes.
    layers: number of stacked cells.
    auxiliary: whether to add an auxiliary head after the second reduction cell.
    genotype: derived architecture used by every Cell.
  """

  def __init__(self, C, num_classes, layers, auxiliary, genotype):
    super(NetworkCIFAR, self).__init__()
    self._layers = layers
    self._auxiliary = auxiliary
    # Drop-path probability read by forward(); the training loop updates it each epoch.
    self.drop_path_prob = 0.0

    stem_multiplier = 3
    C_curr = stem_multiplier * C
    # Stem: 3x3 conv + BN.
    self.stem = nn.Sequential(
      nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
      nn.BatchNorm2d(C_curr)
    )

    C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
    self.cells = nn.ModuleList()
    reduction_prev = False
    for i in range(layers):
      # Cells at layers//3 and 2*layers//3 are reduction cells (channels doubled).
      if i in [layers // 3, 2 * layers // 3]:
        C_curr *= 2
        reduction = True
      else:
        reduction = False
      cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
      reduction_prev = reduction
      self.cells += [cell]
      C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
      # Unlike the search network, an auxiliary head taps the output of cell 2*layers//3.
      if i == 2 * layers // 3:
        C_to_auxiliary = C_prev

    if auxiliary:
      self.auxiliary_head = AuxiliaryHeadCIFAR(C_to_auxiliary, num_classes)
    # The main head is needed whether or not the auxiliary head exists.
    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(C_prev, num_classes)

  def forward(self, input):
    """Return (logits, logits_aux); logits_aux is None unless auxiliary and training."""
    logits_aux = None
    s0 = s1 = self.stem(input)
    for i, cell in enumerate(self.cells):
      s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
      if i == 2 * self._layers // 3:
        if self._auxiliary and self.training:
          logits_aux = self.auxiliary_head(s1)
    out = self.global_pooling(s1)
    logits = self.classifier(out.view(out.size(0), -1))
    return logits, logits_aux



'''
train.py
子網訓練
'''
# NOTE(review): as excerpted, this loop runs before `train` / `infer` are
# defined below; in the full script it presumably lives inside main() - confirm.
for epoch in range(args.epochs):
    # ...
    # Drop-path probability ramps linearly from 0 towards args.drop_path_prob.
    model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
    train_acc, train_obj = train(train_queue, model, criterion, optimizer)
    valid_acc, valid_obj = infer(valid_queue, model, criterion)
    # ...
    # Since PyTorch 1.1 the scheduler must step after the epoch's optimizer
    # steps; stepping first skips the initial learning-rate value.
    scheduler.step()

def train(train_queue, model, criterion, optimizer):
  """Train the evaluation network for one epoch.

  When args.auxiliary is set, the auxiliary-head loss is added with weight
  args.auxiliary_weight.

  Returns:
    (average top-1 accuracy, average loss) over the epoch.
  """
  objs = utils.AvgrageMeter()
  top1 = utils.AvgrageMeter()
  top5 = utils.AvgrageMeter()
  model.train()

  for step, (input, target) in enumerate(train_queue):
    input = input.cuda()
    target = target.cuda()
    optimizer.zero_grad()
    logits, logits_aux = model(input)
    loss = criterion(logits, target)
    # Auxiliary head loss, weighted.
    if args.auxiliary:
      loss_aux = criterion(logits_aux, target)
      loss += args.auxiliary_weight * loss_aux
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
    optimizer.step()

    prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
    n = input.size(0)
    objs.update(loss.item(), n)
    top1.update(prec1.item(), n)
    top5.update(prec5.item(), n)

    if step % args.report_freq == 0:
      logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

  return top1.avg, objs.avg


def infer(valid_queue, model, criterion):
  """Evaluate the network on valid_queue without tracking gradients.

  Returns:
    (average top-1 accuracy, average loss).
  """
  objs = utils.AvgrageMeter()
  top1 = utils.AvgrageMeter()
  top5 = utils.AvgrageMeter()
  model.eval()

  # Variable(volatile=True) has no effect on modern PyTorch; no_grad actually
  # disables graph construction during evaluation.
  with torch.no_grad():
    for step, (input, target) in enumerate(valid_queue):
      input = input.cuda()
      target = target.cuda()

      # The auxiliary output is None in eval mode and is not used here.
      logits, _ = model(input)
      loss = criterion(logits, target)

      prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
      n = input.size(0)
      objs.update(loss.item(), n)
      top1.update(prec1.item(), n)
      top5.update(prec5.item(), n)

      if step % args.report_freq == 0:
        logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

  return top1.avg, objs.avg
  
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章