從源碼解讀Large-Margin Softmax Loss for Convolutional Neural Networks

從源碼解讀Large-Margin Softmax Loss for Convolutional Neural Networks

1.論文回顧

論文地址:https://arxiv.org/pdf/1612.02295.pdf

L-softmax的主要思想是通過一個超參m對softmax+cross entropy的損失函數進行改進。一般我們把y = Wx + b, output = softmax(y), cross_entropy(output, label)這個過程統稱爲softmax loss.

從softmax到L-softmax的改進在論文中已經解釋的非常清楚了。
image

損失函數可以寫成下面的形式
image

關於角度的問題,我們需要設計一個單調遞減的函數。我的理解是由於cos函數是一個周期函數,當m*theta > pi之後,cos(theta)會進入上升階段。而很明顯,在(4)這個式子中,theta(yi)越大,我們需要對這個限制的越厲害,因此也就需要一個更小的phi值。所以作者設計瞭如下一個phi函數。
image

在實現的過程中,我們利用cos(theta)的定義和多倍角公式,得到下面的式子:
image

2.pytorch版本源碼解讀

L-softmax有多種框架的實現版本,作者使用的是caffe,本文不詳細介紹反傳過程,因此選擇pytorch實現版本進行解讀。pytorch版本實現地址:https://github.com/jihunchoi/lsoftmax-pytorch/blob/master/lsoftmax.py。

實現代碼

	import math
	
	import torch
	from torch import nn
	from torch.autograd import Variable
	
	from scipy.special import binom
	
	
	class LSoftmaxLinear(nn.Module):
	
	    def __init__(self, input_dim, output_dim, margin):
	        super().__init__()
	        self.input_dim = input_dim
	        self.output_dim = output_dim
	        self.margin = margin
	
	        self.weight = nn.Parameter(torch.FloatTensor(input_dim, output_dim))
	
	        self.divisor = math.pi / self.margin
	        '''
	        #這個是係數,對應(7)式中前面的係數。使用一個二項分佈的函數,產生係數,並且每隔1個取一個數字。假設m=3,則有[1, 3, 3, 1], 隔開取之後是[1, 3].m = 4則有[1, 4, 6, 4, 1],隔開取之後是[1, 6, 1],與7中的係數是對應的關係
	        '''
	        self.coeffs = binom(margin, range(0, margin + 1, 2))
	        '''
	        這個是cos項的指數,從m到0,每次減去2
	        '''
	        self.cos_exps = range(self.margin, -1, -2)
	        '''
	        這個是sin平方項的指數,從1到n,與cos項的指數對應
	        '''
	        self.sin_sq_exps = range(len(self.cos_exps))
	        '''
	        這個是符號 1 -1 1 -1 1 -1....
	        '''
	        self.signs = [1]
	        for i in range(1, len(self.sin_sq_exps)):
	            self.signs.append(self.signs[-1] * -1)
	
	    def reset_parameters(self):
	        nn.init.kaiming_normal(self.weight.data.t())
	
	    def find_k(self, cos):
	        acos = cos.acos()
	        k = (acos / self.divisor).floor().detach()
	        return k
	
	    def forward(self, input, target=None):
	        '''
	        input: N,D 其中N是batch size
	        target是(N,)的label
	        '''
	        if self.training:
	            assert target is not None
	            '''
	            y = Wx這樣得到輸出邏輯, logit的維度應該是(N, C), C是輸出類別數
	            '''
	            logit = input.matmul(self.weight)
	            batch_size = logit.size(0)# N
	            '''
	            通過這個操作把N, C矩陣中的yi全部都取出來,形成一個一維向量 (N , )
	            '''
	            logit_target = logit[range(batch_size), target]
	            '''
	            求L2範數
	            '''
	            weight_target_norm = self.weight[:, target].norm(p=2, dim=0)
	            input_norm = input.norm(p=2, dim=1)
	            # norm_target_prod: (batch_size,)
	            norm_target_prod = weight_target_norm * input_norm
	            # cos_target: (batch_size,)
	            '''
	            這裏就得到了cos(theta)和sin(theta)的二次方
	            '''
	            cos_target = logit_target / (norm_target_prod + 1e-10)
	            sin_sq_target = 1 - cos_target**2
	
	            num_ns = self.margin//2 + 1# m = 4時爲3, m = 3時爲2 ,讓這個數爲n
	            # coeffs, cos_powers, sin_sq_powers, signs: (num_ns,)注意這裏的shape
	            coeffs = Variable(input.data.new(self.coeffs))
	            cos_exps = Variable(input.data.new(self.cos_exps))
	            sin_sq_exps = Variable(input.data.new(self.sin_sq_exps))
	            signs = Variable(input.data.new(self.signs))
	            
	            '''
	            (N, 1) ** (1, n)-->(N, n)這個矩陣是batch中每個example的cos
	            '''
	            cos_terms = cos_target.unsqueeze(1) ** cos_exps.unsqueeze(0)
	            '''
	            同上,這個矩陣是batch中每個example的sin平方
	            '''
	            sin_sq_terms = (sin_sq_target.unsqueeze(1)
	                            ** sin_sq_exps.unsqueeze(0))
	
	            '''
	            符號*係數*cos*sin平方  (1, n) * (1, n) * (N, n) * (N , n)
	            '''
	            cosm_terms = (signs.unsqueeze(0) * coeffs.unsqueeze(0)
	                          * cos_terms * sin_sq_terms)
	            # 各個example各自求和,得到 cos (m * theta)
	            cosm = cosm_terms.sum(1)
	            # 尋找k值
	            k = self.find_k(cos_target)
	            # 根據k值計算||W||*||x||*cos (phi)
	            ls_target = norm_target_prod * (((-1)**k * cosm) - 2*k)
	            # 把計算出來的值代替原來的softmax中yi的位置,返回之後通過cross entropy計算就得到了Lsoftmax
	            logit[range(batch_size), target] = ls_target
	            return logit
	        else:
	            assert target is None
	            return input.matmul(self.weight)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章