深度學習乾貨學習(1)——center loss

在構建loss時pytorch常用的包中有最常見的MSE、cross entropy(logsoftmax+NLLLoss)、KL散度Loss、BCE、HingeLoss等等,詳見:https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch-nn/#loss-functions

這裏主要講解一種考慮類間距離的Center Loss:
一、簡介:

center loss來自ECCV2016的一篇論文:A Discriminative Feature Learning Approach for Deep Face Recognition。
論文鏈接:http://ydwen.github.io/papers/WenECCV16.pdf

 
二、爲什麼要使用Center Loss:

    . In most of the available CNNs, the softmax loss function is used as the supervision signal to train the deep model. In order to enhance the discriminative power of the deeply learned features, this paper proposes a new supervision signal, called center loss

    the center loss simultaneously learns a center for deep features of each class and penalizes the distances between the deep features and their corresponding class centers

簡單的來說,我們在做分類(無論是image、instance、pixel level)的時候,我們不光需要學得separable的特徵,更想要這些特徵是discriminative的,這就意味着我們需要在loss上做更多的約束。

僅僅使用softmax作爲監督信號的輸出處理就只能做到seperable而不是discriminative,如下圖:

 
三、如何使學到的特徵差異化更大——Center Loss:

    Specifically, we learn a center (a vector with the same dimension as a feature) for deep features of each class.

    The CNNs are trained under the joint supervision of the softmax loss and center loss, with a hyper parameter to balance the two supervision signals.

融合Softmax Loss 與 Center Loss:

Softmax Loss (保證類之間的feature距離最大)與 Center Loss (保證類內的feature距離最小,更接近於類中心)

m是mini-batch、n是class。在Lc公式中有一個缺陷,就是Cyi是i這個樣本對應的類別yi所屬於的類中心C∈ Rd,d代表d維。

理想情況下,Cyi需要隨着學到的feature變化而實時更新,也就是要在每一次迭代中用整個數據集的feature來算每個類的中心。

但這顯然不現實,做以下兩個修改:

1、由整個訓練集更新center改爲mini-batch更改center

2、避免錯誤分類的樣本的干擾,使用scalar α 來控制center的學習率

因此求算梯度的公式如下:

即:當yi = j,也就是mini-batch中某一個sample是對應要更新的那一個類的center的時候就累加起來除以某類的個數+1。

最終loss聯立起來如上圖,λ用於平衡softmax loss與center loss,越大則區分度 越大,如下圖效果:

 
四、Center Loss的實現:

在三種我們清楚了原理,保證分類情況下的intra-class loss最小。下面講解如何在代碼和結構中實現:

pytorch的使用者可以參看:https://github.com/jxgu1016/MNIST_center_loss_pytorch

(1)網絡結構:

即在特徵層輸出(classification前最後一層)引入center loss:

(2)如果任有不明白結合algorithm理解:

(3)Code:

    class CenterLoss(nn.Module):
        def __init__(self, num_classes, feat_dim, size_average=True):
            super(CenterLoss, self).__init__()
            self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))
            self.centerlossfunc = CenterlossFunc.apply
            self.feat_dim = feat_dim
            self.size_average = size_average
     
        def forward(self, label, feat):
            batch_size = feat.size(0)
            feat = feat.view(batch_size, -1)
            # To check the dim of centers and features
            if feat.size(1) != self.feat_dim:
                raise ValueError("Center's dim: {0} should be equal to input feature's dim: {1}".format(self.feat_dim,feat.size(1)))
            loss = self.centerlossfunc(feat, label, self.centers)
            loss /= (batch_size if self.size_average else 1)
            return loss
     
     
    class CenterlossFunc(Function):
        @staticmethod
        def forward(ctx, feature, label, centers):
            ctx.save_for_backward(feature, label, centers)
            centers_batch = centers.index_select(0, label.long())
            return (feature - centers_batch).pow(2).sum() / 2.0
     
        @staticmethod
        def backward(ctx, grad_output):
            feature, label, centers = ctx.saved_tensors
            centers_batch = centers.index_select(0, label.long())
            diff = centers_batch - feature
            # init every iteration
            counts = centers.new(centers.size(0)).fill_(1)
            ones = centers.new(label.size(0)).fill_(1)
            grad_centers = centers.new(centers.size()).fill_(0)
     
            counts = counts.scatter_add_(0, label.long(), ones)
            grad_centers.scatter_add_(0, label.unsqueeze(1).expand(feature.size()).long(), diff)
            grad_centers = grad_centers/counts.view(-1, 1)
            return - grad_output * diff, None, grad_centers
     
     
    def main(test_cuda=False):
        print('-'*80)
        device = torch.device("cuda" if test_cuda else "cpu")
        ct = CenterLoss(10,2).to(device)
        y = torch.Tensor([0,0,2,1]).to(device)
        feat = torch.zeros(4,2).to(device).requires_grad_()
        print (list(ct.parameters()))
        print (ct.centers.grad)
        out = ct(y,feat)
        print(out.item())
        out.backward()
        print(ct.centers.grad)
        print(feat.grad)

 
五、擴展:

center loss 與 constrastive loss 以及 triplet loss的區別在原文中也有給出,center loss相對於contrastive和triplet loss的優點顯然省去了複雜並且含糊的樣本對構造過程,接下來會對triplet loss做一個梳理。
---------------------  
作者:每天都要深度學習  
來源:CSDN  
原文:https://blog.csdn.net/Lucifer_zzq/article/details/81236174  
版權聲明:本文爲博主原創文章,轉載請附上博文鏈接!

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章