【手撕 - 自然語言處理】手撕 FastText 源碼(01)分類器的預測過程

作者:LogM

本文原載於 https://segmentfault.com/u/logm/articles ,不允許轉載~

1. 源碼來源

FastText 源碼:https://github.com/facebookre...

本文對應的源碼版本:Commits on Jun 27 2019, 979d8a9ac99c731d653843890c2364ade0f7d9d3

FastText 論文:

[1] P. Bojanowski, E. Grave, A. Joulin, T. Mikolov, Enriching Word Vectors with Subword Information

[2] A. Joulin, E. Grave, P. Bojanowski, T. Mikolov, Bag of Tricks for Efficient Text Classification

2. 概述

FastText 的論文寫的比較簡單,有些細節不明白,網上也查不到,所幸直接撕源碼。

FastText 的"分類器"功能是用的最多的,所以先從"分類器的predict"開始挖。

3. 開撕

先看程序入口的 main 函數,ok,是調用了 predict 函數。

// 文件:src/main.cc
// 行數:403
int main(int argc, char** argv) {
  std::vector<std::string> args(argv, argv + argc);
  if (args.size() < 2) {
    printUsage();
    exit(EXIT_FAILURE);
  }
  std::string command(args[1]);
  if (command == "skipgram" || command == "cbow" || command == "supervised") {
    train(args);           
  } else if (command == "test" || command == "test-label") {
    test(args);
  } else if (command == "quantize") {
    quantize(args);
  } else if (command == "print-word-vectors") {
    printWordVectors(args);
  } else if (command == "print-sentence-vectors") {
    printSentenceVectors(args);
  } else if (command == "print-ngrams") {
    printNgrams(args);
  } else if (command == "nn") {
    nn(args);
  } else if (command == "analogies") {
    analogies(args);
  } else if (command == "predict" || command == "predict-prob") {
    predict(args);       // 這句是我們想要的
  } else if (command == "dump") {
    dump(args);
  } else {
    printUsage();
    exit(EXIT_FAILURE);
  }
  return 0;
}

再看 predict 函數,預處理的代碼不用管,直接看 predict 的那行,調用了 FastText::predictLine。這裏注意下,這是個 while 循環,所以
FastText::predictLine 這個函數每次只處理一行。

// 文件:src/main.cc
// 行數:205

void predict(const std::vector<std::string>& args) {
  if (args.size() < 4 || args.size() > 6) {
    printPredictUsage();
    exit(EXIT_FAILURE);
  }
  int32_t k = 1;
  real threshold = 0.0;
  if (args.size() > 4) {
    k = std::stoi(args[4]);
    if (args.size() == 6) {
      threshold = std::stof(args[5]);
    }
  }

  bool printProb = args[1] == "predict-prob";
  FastText fasttext;
  fasttext.loadModel(std::string(args[2]));

  std::ifstream ifs;
  std::string infile(args[3]);
  bool inputIsStdIn = infile == "-";
  if (!inputIsStdIn) {
    ifs.open(infile);
    if (!inputIsStdIn && !ifs.is_open()) {
      std::cerr << "Input file cannot be opened!" << std::endl;
      exit(EXIT_FAILURE);
    }
  }
  std::istream& in = inputIsStdIn ? std::cin : ifs;
  std::vector<std::pair<real, std::string>> predictions;
  while (fasttext.predictLine(in, predictions, k, threshold)) {     // 這句是重點
    printPredictions(predictions, printProb, false);
  }
  if (ifs.is_open()) {
    ifs.close();
  }

  exit(0);
}

再看 FastText::predictLine,注意這邊有兩個重點。

// 文件:src/fasttext.cc
// 行數:451
bool FastText::predictLine(
    std::istream& in,
    std::vector<std::pair<real, std::string>>& predictions,
    int32_t k,
    real threshold) const {
  predictions.clear();
  if (in.peek() == EOF) {
    return false;
  }

  std::vector<int32_t> words, labels;
  dict_->getLine(in, words, labels);                // 這句是第一個重點
  Predictions linePredictions;
  predict(k, words, linePredictions, threshold);    // 這句是第二個重點
  for (const auto& p : linePredictions) {
    predictions.push_back(
        std::make_pair(std::exp(p.first), dict_->getLabel(p.second)));
  }

  return true;
}

先看第一個重點,getLine 函數其實是 Dictionary::getLine,定義在src/dictionary.cc

這段代碼的乾貨度還是很高的,裏面有兩個重點,Dictionary::addSubwordsDictionary::addWordNgrams,以後會講。這邊只要知道整個函數把讀到的這一行的每個Id(包括詞語的id,SubWords的Id,WordNgram的Id),存到了數組 words 中。

// 文件:src/dictionary.cc
// 行數:378
int32_t Dictionary::getLine(
    std::istream& in,
    std::vector<int32_t>& words,
    std::vector<int32_t>& labels) const {
  std::vector<int32_t> word_hashes;
  std::string token;
  int32_t ntokens = 0;

  reset(in);
  words.clear();
  labels.clear();
  while (readWord(in, token)) {     // `token` 是讀到的一個詞語,如果讀到一行的行尾,則返回`EOF`
    uint32_t h = hash(token);       // 找到這個詞語位於哪個hash桶
    int32_t wid = getId(token, h);      // 在hash桶中找到這個詞語的Id,如果負數就是沒找到對應的Id
    entry_type type = wid < 0 ? getType(token) : getType(wid);   // 如果沒找到對應Id,則有可能是label,`getType`裏會處理

    ntokens++;
    if (type == entry_type::word) {
      addSubwords(words, token, wid);   // 重點1,以後會講
      word_hashes.push_back(h);
    } else if (type == entry_type::label && wid >= 0) {
      labels.push_back(wid - nwords_);
    }
    if (token == EOS) {
      break;
    }
  }
  addWordNgrams(words, word_hashes, args_->wordNgrams);  // 重點2,以後會講
  return ntokens;
}

再來看第二個重點, FastText::predict 函數,重點是 Model::predict 函數。

// 文件:src/fasttext.cc
// 行數:437
void FastText::predict(
    int32_t k,
    const std::vector<int32_t>& words,
    Predictions& predictions,
    real threshold) const {
  if (words.empty()) {
    return;
  }
  Model::State state(args_->dim, dict_->nlabels(), 0);
  if (args_->model != model_name::sup) {
    throw std::invalid_argument("Model needs to be supervised for prediction!");
  }
  model_->predict(words, k, threshold, predictions, state);       // 這句是重點
}

來到 Model::predict,有兩個重點.

其中 Loss::predict 是將 hidden 層的輸出結果進行 softmax 後得到最終概率最大的k個類別,"分類器的predict" 用的是經典的softmax,所以代碼也比較簡單。而如果是"分類器的train" 則涉及到 Hierarchical SoftmaxLossNegativeSamplingLoss 等一些加速手段,比較複雜,以後有機會再講。

// 文件:src/model.cc
// 行數:53
void Model::predict(
    const std::vector<int32_t>& input,
    int32_t k,
    real threshold,
    Predictions& heap,
    State& state) const {
  if (k == Model::kUnlimitedPredictions) {
    k = wo_->size(0); // output size
  } else if (k <= 0) {
    throw std::invalid_argument("k needs to be 1 or higher!");
  }
  heap.reserve(k + 1);
  computeHidden(input, state);      // 重點1

  loss_->predict(k, threshold, heap, state);    // 重點2,以後再講
}

我們再來看另一個重點,Model::computeHidden 函數。

Model::computeHidden 函數理解起來比較簡單,注意這裏的 input 就是前面的 words,是一系列id組成的數組(包括詞語的id,SubWords的Id,WordNgram的Id),把這些求和,然後取平均。

當然有些小夥伴可能有點疑問,Vector::addRow 爲什麼是求和,這個以後再講吧。

// 文件:src/model.cc
// 行數:43
void Model::computeHidden(const std::vector<int32_t>& input, State& state)
    const {
  Vector& hidden = state.hidden;
  hidden.zero();
  for (auto it = input.cbegin(); it != input.cend(); ++it) {
    hidden.addRow(*wi_, *it);           // 求和
  }
  hidden.mul(1.0 / input.size());       // 然後取平均
}

4. 總結

至此,FastText裏面的"分類器的predict"的大致流程講完了,其他的,如"分類器的train"和"詞向量"的源碼也是類似的方法來閱讀。

這裏面有幾段代碼沒有詳細敘述:Dictionary::addSubwordsDictionary::addWordNgramsVector::addRow以及訓練時softmax的加速,先把坑留着,以後有時間再填。

發佈了52 篇原創文章 · 獲贊 19 · 訪問量 5萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章