Caffe Source Code Reading (2): center_loss_layer.cpp

  • Definition of the center loss formula

  • Walkthrough of the center_loss_layer.cpp source code

  • Comparison experiment: center loss vs. softmax loss on the MNIST dataset

Definition

“Center Loss: simultaneously learning a center for deep features of each class and penalizing the distances between the deep features and their corresponding class centers.” Reference paper: A Discriminative Feature Learning Approach for Deep Face Recognition. For an intuitive understanding of center loss, the answers on Zhihu are also worth reading.

Formulas

(1) Forward Computation

$$L_c = \frac{1}{2N}\sum_{i=1}^{N}\left\lVert x_i - c_{y_i}\right\rVert_2^2 \tag{1}$$

where $N$ is the batch size, $x_i$ is the deep feature of sample $i$, and $c_{y_i}$ is the center of its class $y_i$.

(2) Backward Computation

$$\frac{\partial L_c}{\partial x_i} = x_i - c_{y_i} \tag{2}$$

(3) Update Equation

$$\Delta c = \frac{\alpha}{N}\sum_{i=1}^{N}\left(x_i - c_{y_i}\right) \tag{3}$$
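In training, the center loss is used jointly with the softmax loss. The paper's total objective is the softmax loss plus the weighted center loss, where $\lambda$ is the `loss_weight` read in LayerSetUp below:

$$L = L_S + \lambda L_C$$

The paper also updates each center $c_j$ using only the batch samples of class $j$, rather than averaging over all $N$ (here $\delta(\cdot)$ is 1 when the condition holds, 0 otherwise):

$$\Delta c_j = \frac{\sum_{i=1}^{N}\delta(y_i = j)\,(c_j - x_i)}{1 + \sum_{i=1}^{N}\delta(y_i = j)}, \qquad c_j \leftarrow c_j - \alpha\,\Delta c_j$$

The $+1$ in the denominator avoids dividing by zero for classes absent from the batch; in the code below this corresponds to initializing center_update_count_ to 1.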

Code

(1) LayerSetUp

```cpp
namespace caffe {  // closed at the end of Backward_cpu below

template <typename Dtype>
void CenterLossLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  LossLayer<Dtype>::LayerSetUp(bottom, top);
  // Two bottoms: feature = bottom[0]->cpu_data(), label = bottom[1]->cpu_data();
  // one top: top[0]->mutable_cpu_data()[0] = loss.
  CHECK_EQ(bottom[0]->num(), bottom[1]->num());
  int channels = bottom[0]->channels();
  int num = bottom[0]->num();
  // Read the center loss layer's parameters; loss_weight is the lambda
  // in the paper, controlling the weight of the center loss term.
  alpha = this->layer_param_.center_loss_param().alpha();
  lossWeight = this->layer_param_.center_loss_param().loss_weight();
  clusterNum = this->layer_param_.center_loss_param().cluster_num();

  center_info_.Reshape(clusterNum, channels, 1, 1);
  center_loss_.Reshape(num, channels, 1, 1);
  center_update_count_.resize(clusterNum);
  // Initialize the class centers (center_info_'s data) to zero with caffe_set.
  caffe_set(clusterNum * channels, Dtype(0.0), center_info_.mutable_cpu_data());
}
```
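The .cpp references several members (alpha, lossWeight, clusterNum, center_info_, center_loss_, center_update_count_) declared in the layer's header, which this post does not show. A minimal sketch of what that header presumably contains, inferred from the code above (the actual file in the repo may differ):

```cpp
// center_loss_layer.hpp (sketch inferred from the .cpp; not the repo's exact file)
template <typename Dtype>
class CenterLossLayer : public LossLayer<Dtype> {
 public:
  explicit CenterLossLayer(const LayerParameter& param)
      : LossLayer<Dtype>(param) {}
  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
                          const vector<Blob<Dtype>*>& top);
  virtual inline const char* type() const { return "CenterLoss"; }

 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
                           const vector<Blob<Dtype>*>& top);
  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
                            const vector<bool>& propagate_down,
                            const vector<Blob<Dtype>*>& bottom);

  Dtype alpha;       // center learning rate, alpha in formula (3)
  Dtype lossWeight;  // lambda, weight of the center loss term
  int clusterNum;    // number of classes (cluster_num)
  Blob<Dtype> center_info_;  // clusterNum x channels: data = centers, diff = pending update
  Blob<Dtype> center_loss_;  // num x channels: x_i - c_{y_i} per sample
  std::vector<int> center_update_count_;  // per-class sample count (starts at 1)
};
```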

(2) Forward_cpu (forward pass)

```cpp
template <typename Dtype>
void CenterLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  // Two bottom inputs: the deep features and their labels.
  const Dtype* feature = bottom[0]->cpu_data();
  const Dtype* label = bottom[1]->cpu_data();
  int num = bottom[0]->num();
  int channels = bottom[0]->channels();
  // Initialize the loss and reset the pending center updates
  // (stored in center_info_'s diff).
  Dtype loss = 0;
  caffe_set(clusterNum * channels, Dtype(0.0), center_info_.mutable_cpu_diff());
  for (int i = 0; i < clusterNum; ++i) {
    center_update_count_[i] = 1;
  }
  for (int i = 0; i < num; ++i) {
    int targetLabel = label[i];
    // caffe_sub: center_loss_ = feature - center, i.e. x_i - c_{y_i} in the formulas.
    caffe_sub(channels, feature + i * channels,
              center_info_.cpu_data() + targetLabel * channels,
              center_loss_.mutable_cpu_data() + i * channels);
    // Accumulate x_i - c_{y_i} into the class's pending update and count the sample.
    caffe_add(channels, center_loss_.cpu_data() + i * channels,
              center_info_.cpu_diff() + targetLabel * channels,
              center_info_.mutable_cpu_diff() + targetLabel * channels);
    center_update_count_[targetLabel]++;
    // Accumulate the center loss per formula (1); the sum is output via top.
    loss += caffe_cpu_dot(channels, center_loss_.cpu_data() + i * channels,
                          center_loss_.cpu_data() + i * channels)
            * lossWeight / Dtype(2.0) / static_cast<Dtype>(num);
  }
  top[0]->mutable_cpu_data()[0] = loss;
  // Update the class centers per formula (3). Blob::Update() computes
  // data -= diff, so the negative scale moves each center toward its class mean.
  for (int i = 0; i < clusterNum; ++i) {
    Dtype scale = -alpha * lossWeight / Dtype(center_update_count_[i]);
    caffe_scal(channels, scale, center_info_.mutable_cpu_diff() + i * channels);
  }
  center_info_.Update();
}
```
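To make the arithmetic concrete, here is a standalone toy version of the forward computation with made-up numbers (plain C++, no Caffe dependency; illustrative only):

```cpp
#include <cstdio>

// Toy re-implementation of the forward pass: 2 classes, 2-D features,
// centers fixed at zero. It mirrors the caffe_sub / caffe_cpu_dot
// arithmetic above, not the Caffe API.
int main() {
  const int num = 4, channels = 2, clusterNum = 2;
  const float lossWeight = 1.0f;
  float feature[num * channels] = {1, 1,  2, 0,  -1, -1,  0, -2};
  int   label[num]              = {0, 0, 1, 1};
  float centers[clusterNum * channels] = {0, 0, 0, 0};  // centers start at zero

  float loss = 0;
  for (int i = 0; i < num; ++i) {
    const float* x = feature + i * channels;
    const float* c = centers + label[i] * channels;
    float sq = 0;
    for (int k = 0; k < channels; ++k) {
      float d = x[k] - c[k];  // x_i - c_{y_i}, what center_loss_ stores
      sq += d * d;
    }
    loss += sq * lossWeight / 2.0f / num;  // formula (1), scaled as in the layer
  }
  std::printf("center loss = %f\n", loss);  // (2 + 4 + 2 + 4) / 8 = 1.5
  return 0;
}
```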

(3) Backward_cpu (backward pass)

```cpp
template <typename Dtype>
void CenterLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down,
    const vector<Blob<Dtype>*>& bottom) {
  int num = bottom[0]->num();
  int channels = bottom[0]->channels();
  // center_loss_'s data still holds x_i - c_{y_i} from the forward pass;
  // scaling it by lossWeight gives the gradient of formula (2).
  caffe_scal(num * channels, lossWeight, center_loss_.mutable_cpu_data());
  // Copy the scaled differences into bottom[0]'s diff for backpropagation.
  Dtype* out = bottom[0]->mutable_cpu_diff();
  caffe_copy(num * channels, center_loss_.cpu_data(), out);
}

}  // namespace caffe
```
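One caveat: this Backward_cpu ignores propagate_down, and the gradient it writes is lossWeight * (x_i - c_{y_i}), without the 1/N factor that formula (1) strictly implies. A more defensive variant following the usual Caffe loss-layer pattern might guard the writes like this (a sketch, not the repo's code):

```cpp
// Sketch of a guarded backward pass (hypothetical; not in the original repo).
if (propagate_down[1]) {
  // Labels are discrete; no gradient flows to bottom[1].
  LOG(FATAL) << this->type() << " Layer cannot backpropagate to label inputs.";
}
if (propagate_down[0]) {
  caffe_scal(num * channels, lossWeight, center_loss_.mutable_cpu_data());
  caffe_copy(num * channels, center_loss_.cpu_data(),
             bottom[0]->mutable_cpu_diff());
}
```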

Experiments

The complete project is open-sourced on GitHub ([link](https://github.com/wangwen39/center-loss)) and makes a good exercise for newcomers. For feature visualization, refer directly to the Caffe example notebook: http://nbviewer.jupyter.org/github/BVLC/caffe/blob/master/examples/siamese/mnist_siamese.ipynb
The MNIST dataset has 10 classes, the handwritten digits 0-9. The comparison experiment shows that center loss increases the distances between classes while making each class more compact internally, which yields better classification accuracy.
① softmax loss only (feature-distribution figure omitted)
② center loss + softmax loss (feature-distribution figure omitted)
