Notes on using LR (logistic regression) for multi-class classification

1. Starting from a probabilistic viewpoint, the posterior probability of a sample belonging to class Ck is inferred as a normalized exponential (softmax) over per-class scores a_k, where a_k is the log of the class-conditional density times the class prior. Equation (4.63) can take a fairly compact form, for example a linear expression in x.
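A sketch of the equations being referenced here, presumably (4.62)-(4.63) from PRML's section on probabilistic generative models (reconstructed from context; the original images are not reproduced):

  p(C_k \mid x) = \frac{p(x \mid C_k)\, p(C_k)}{\sum_j p(x \mid C_j)\, p(C_j)} = \frac{\exp(a_k)}{\sum_j \exp(a_j)}

where

  a_k = \ln \big( p(x \mid C_k)\, p(C_k) \big).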


2. Assuming P(x | Ck) is a normal (Gaussian) distribution, ln p(x|Ck)p(Ck) can then be written as a linear expression, as sketched below.
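A sketch of that linear form, under the usual assumption (an assumption, not stated in the note) that every class shares one covariance matrix Σ, so the quadratic term in x is common to all classes and cancels inside the softmax:

  a_k = \ln p(x \mid C_k) + \ln p(C_k) = w_k^{\top} x + w_{k0},

with

  w_k = \Sigma^{-1} \mu_k, \qquad w_{k0} = -\tfrac{1}{2}\, \mu_k^{\top} \Sigma^{-1} \mu_k + \ln p(C_k).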


3. Solving for the model parameters (see the sketch below):
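The sample code below runs batch gradient descent on the L2-regularized multi-class cross-entropy error. A sketch of the update it implements, where Y_nk is the softmax posterior P(C_k | x_n), T_nk the 1-of-K target, α the learning rate, and λ the regularization weight:

  E(W) = -\sum_n \sum_k T_{nk} \ln Y_{nk} + \frac{\lambda}{2} \sum_k \lVert w_k \rVert^2

  \nabla_{w_k} E = \sum_n (Y_{nk} - T_{nk})\, x_n + \lambda\, w_k

  w_k \leftarrow w_k - \alpha\, \nabla_{w_k} E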


4. Fundamentally, taking the probabilistic route to classification is more sound: intuitively, where the points of one class are dense, the probability of that class is high. When squared error is used as the loss, plain distance does not really measure how far a point is from a class.

SVC also uses a similar notion of distance, but SVC only relies on the support vectors and projects the data into a higher-dimensional space.


5. Sample code:

#include <string>
#include <vector>
#include <cmath>
#include <map>

#include "base/flags.h"
#include "base/string_util.h"
#include "utils/hash_tables.h"
#include "common/file/simple_line_reader.h"

DEFINE_string(train_path, "./test.txt", "training file");
DEFINE_double(lambda, 0.001, "the weight of regularization");
DEFINE_double(alpha, 0.1, "the learning rate");
DEFINE_int32(n, 5000, "number of training iterations");

// original data
struct DataSample {
  std::string                           label;
  double                                predict_prob;
  utils::hash_map<std::string, double>  features;

  void AddFeature(const std::string& fn, const double& v) {
    features[fn] = v;
  }
};

// internal labels are: [0, 1, 2, ..., label_count_ - 1]
class TrainningDataSet {
public:
  TrainningDataSet() {
    label_count_ = 0;
    // feature 0 is the constant bias term "cb", so x[0] is always 1
    inner_feature_map_["cb"] = 0;
    outer_feature_map_[0] = "cb";
    feature_count_ = 1;
  }

  // format like:  "A 1:0.2 2:0.4 3:77 4:0.3"
  bool LoadSamplesFromFile(const std::string& file_path) {
    file::SimpleLineReader  line_reader;
    line_reader.OpenOrDie(file_path);
    std::vector<std::string> lines;
    line_reader.ReadLines(&lines);

    for (size_t i = 0; i < lines.size(); ++i) {
      std::vector<std::string> parts;
      DataSample sample;
      SplitString(lines[i], ' ', &parts);
      if (parts.empty()) {
        continue;  // skip blank lines
      }

      sample.label = parts[0];
      AddLabel(parts[0]);

      for (size_t j = 1; j < parts.size(); ++j) {
        std::vector<std::string> fn_v;
        SplitString(parts[j], ':', &fn_v);
        if (fn_v.size() != 2) {
          continue;
        }
        double v = 0.0;
        StringToDouble(fn_v[1], &v);
        sample.AddFeature(fn_v[0], v);
        AddFeature(fn_v[0]);
      }
      samples_.push_back(sample);
    }
    return true;
  }

  void Train() {
    AllocAuxParam();
    TrainInternal(FLAGS_n);
    Predict();
    FreeAuxParam();
  }

  void Predict() {
    std::vector<double> prob(label_count_);
    // score each class k for sample Xn, then normalize with softmax
    for (auto it = samples_.begin(); it != samples_.end(); ++it) {
      for (int k = 0; k < label_count_; ++k) {
        prob[k] = w[k][0]*1.0f;  // bias term, x[0] = 1
        for (auto sit = it->features.begin(); sit != it->features.end(); ++sit) {
          prob[k] += sit->second*w[k][inner_feature_map_[sit->first]];
        }
        prob[k] = exp(prob[k]);
      }

      double total_exp = 0.0f;
      for (int k = 0; k < label_count_; ++k) {
        total_exp += prob[k];
      }

      it->predict_prob = prob[inner_label_map_[it->label]]/total_exp;
      VLOG(0) << "sampel predict [" << it->label << "]: " << it->predict_prob;
    }
  }

  void TrainInternal(int32 count) {
    for (int32 i = 0; i < count; ++i) {
      //VLOG(0) << "training iteration:  " << (i+1);

      // calculate posterior probabilities P(Ck | xn) via softmax
      for (size_t n = 0; n < samples_.size(); ++n) {
        const DataSample& sample = samples_[n];

        for (int32 k = 0; k < label_count_; ++k) {
          // W*X,  x[0] = 1
          post_prob[n][k] = 1.0f*w[k][0];
          for (auto it = sample.features.begin(); it != sample.features.end(); ++it) {
            const std::string& outer_feature_name = it->first;
            double val = it->second;
            int32 feature_idx = inner_feature_map_[outer_feature_name];
            post_prob[n][k] += val*w[k][feature_idx];
          }
          post_prob[n][k] = exp(post_prob[n][k]);
        }

        double exp_total = 0.0f;
        for (int32 k = 0; k < label_count_; ++k) {
          exp_total += post_prob[n][k];
        }
        for (int32 k = 0; k < label_count_; ++k) {
          post_prob[n][k] /= exp_total;
          //VLOG(0) << "P(C" << k << "|X" << n << ") = " << post_prob[n][k];
        }
      }

      // calculate the gradient dE/dW[k] = sum_n (Ynk - Tnk)*Xn + lambda*W[k]
      for (int32 k = 0; k < label_count_; ++k) {
        for (int32 d = 0; d < feature_count_; ++d) {
          grad[k][d] = 0.0f;
        }

        // iterate over every sample
        for (size_t n = 0; n < samples_.size(); ++n) {
          // iterate over every dimension
          double Tnk = GetTnk(n, k);
          double Ynk = post_prob[n][k];
          for (int32 d = 0; d < feature_count_; ++d) {
            grad[k][d] += GetXnd(n, d)*(Ynk - Tnk);
          }
        }

        // add L2 regularization, then take a gradient-descent step
        //std::string w_str;
        for (int32 d = 0; d < feature_count_; ++d) {
          grad[k][d] += w[k][d]*FLAGS_lambda;
          w[k][d] -= FLAGS_alpha*grad[k][d];
          //w_str.append(outer_feature_map_[d]).append(":").append(DoubleToString(w[k][d])).append(",");
        }
        //VLOG(0) << "[w" << k << "]: " << w_str;
      }
    }
  }

  void Dump() {
    utils::hash_map<std::string, int>::iterator it;
    for (it = inner_label_map_.begin(); it != inner_label_map_.end(); ++it) {
      VLOG(0) << "label: " << it->first << ", " << it->second;
    }

    for (it = inner_feature_map_.begin(); it != inner_feature_map_.end(); ++it) {
      VLOG(0) << "feature: " << it->first << ", " << it->second;
    }
  }

private:
  double**   w;           // w[k][d]    update:  w[k] = w[k] - alpha*grad[k]
  double**   grad;        // grad[k][d]:  grad[k] = sum_n (Ynk - Tnk)*Xn + lambda*w[k]
  double**   post_prob;   // post_prob[n][k]:  p[n][k] = P(Ck | xn)

  std::vector<DataSample>  samples_;

  utils::hash_map<std::string, int>  inner_label_map_;  // 'A' -> 0,  'B' -> 1
  utils::hash_map<int, std::string>  outer_label_map_;  //  0  -> 'A',  1  -> 'B'
  int32 label_count_;

  int AddLabel(const std::string& outer_label) {
    utils::hash_map<std::string, int>::iterator it = inner_label_map_.find(outer_label);
    if (it == inner_label_map_.end()) {
      inner_label_map_[outer_label] = label_count_;
      outer_label_map_[label_count_] = outer_label;
      label_count_++;
    }
    return inner_label_map_[outer_label];
  }

  utils::hash_map<std::string, int>  inner_feature_map_;  // "cb" (const bias) -> 0,  "1" -> 1,  "2" -> 2,  "url_host_big" -> feature_count_ - 1
  utils::hash_map<int, std::string>  outer_feature_map_;
  int32 feature_count_;

  void AddFeature(const std::string& feature_name) {
    utils::hash_map<std::string, int>::iterator it = inner_feature_map_.find(feature_name);
    if (it == inner_feature_map_.end()) {
      inner_feature_map_[feature_name] = feature_count_;
      outer_feature_map_[feature_count_] = feature_name;
      feature_count_++;
    }
  }

  // X(n, d): value of feature d for sample n; d == 0 is the constant bias (always 1)
  double GetXnd(const int32& n, const int32& d) {
    double ret = 0.0f;
    if (d == 0) {
      ret = 1.0f;
    } else {
      DataSample& sample = samples_[n];
      std::string& ol = outer_feature_map_[d];
      auto it = sample.features.find(ol);
      if (it != sample.features.end()) {
        ret = it->second;
      }
    }
    //VLOG(0) << "X(" << n << "," << d << ") = " << ret;
    return ret;
  }

  // T(n, k): 1-of-K target, 1.0 if sample n belongs to class k, else 0.0
  double GetTnk(const int32& n, const int32& k) {
    double ret = 0.0f;
    if (inner_label_map_[samples_[n].label] == k) {
      ret = 1.0f;
    }
    //VLOG(0) << "T(" << n << "," << k << ") = " << ret;
    return ret;
  }

  void FreeAuxParam() {
    for (int k = 0; k < label_count_; ++k) {
      delete[] w[k];
      delete[] grad[k];
    }
    delete[] w;
    delete[] grad;

    for (size_t n = 0; n < samples_.size(); ++n) {
      delete[] post_prob[n];
    }
    delete[] post_prob;
  }

  void AllocAuxParam() {
    // w[k][d], grad[k][d]
    w = new double*[label_count_];
    grad = new double*[label_count_];
    for (int k = 0; k < label_count_; ++k) {
      w[k] = new double[feature_count_];
      grad[k] = new double[feature_count_];
      for (int f = 0; f < feature_count_; ++f) {
        w[k][f] = 0.0f;
      }
    }

    // post_prob[n][k]
    post_prob = new double*[samples_.size()];
    for (size_t n = 0; n < samples_.size(); ++n) {
      post_prob[n] = new double[label_count_];
    }
  }
};

int main(int argc, char* argv[]) {
  base::ParseCommandLineFlags(&argc, &argv, false);
  TrainningDataSet  tds;
  tds.LoadSamplesFromFile(FLAGS_train_path);
  tds.Dump();
  tds.Train();

  return 0;
}

Test samples:

A 1:0.20 2:0.70
A 1:0.10 2:0.80
A 1:0.30 2:0.60
A 1:0.05 2:0.94
A 1:0.77 2:0.22
A 1:0.44 2:0.55
B 1:0.20 2:0.81
B 1:0.30 2:0.71
B 1:1.00 2:0.01
B 1:0.50 2:0.51
B 1:0.40 2:0.65
B 1:0.70 2:0.40

Results:

I0930 09:53:14.443656 32046 lr.cc:100] sampel predict [A]: 0.926048
I0930 09:53:14.443763 32046 lr.cc:100] sampel predict [A]: 0.932346
I0930 09:53:14.443836 32046 lr.cc:100] sampel predict [A]: 0.919215
I0930 09:53:14.443896 32046 lr.cc:100] sampel predict [A]: 0.592285
I0930 09:53:14.443940 32046 lr.cc:100] sampel predict [A]: 0.421599
I0930 09:53:14.443987 32046 lr.cc:100] sampel predict [A]: 0.499967
I0930 09:53:14.444035 32046 lr.cc:100] sampel predict [B]: 0.569759
I0930 09:53:14.444111 32046 lr.cc:100] sampel predict [B]: 0.593065
I0930 09:53:14.444164 32046 lr.cc:100] sampel predict [B]: 0.740223
I0930 09:53:14.444211 32046 lr.cc:100] sampel predict [B]: 0.638351
I0930 09:53:14.444258 32046 lr.cc:100] sampel predict [B]: 0.816627
I0930 09:53:14.444300 32046 lr.cc:100] sampel predict [B]: 0.955107

The few points in the middle are very close to each other, so their predicted probabilities are not that high.


