Spark MLlib源碼分析—TFIDF源碼詳解

以下代碼是我依據SparkMLlib(版本1.6)
1、HashingTF 是使用哈希表來存儲分詞,並計算分詞頻數(TF),生成HashMap表。在Map中,K爲分詞對應索引號,V爲分詞的頻數。在聲明HashingTF 時,需要設置numFeatures,該屬性實爲設置哈希表的大小;如果設置numFeatures過小,則在存儲分詞時會出現重疊現象,所以不要設置太小,一般情況下設置爲30w~50w之間。
2、IDF是計算每個分詞出現在文章中的次數,並計算log值。在聲明IDF時,可以設置minDocFreq,即過濾掉出現文章數小於minDocFreq的分詞。
3、IDFModel 主要是計算TF*IDF,另外IDFModel也可以將IDF數據保存下來(即模型的保存),在測試語料時,只需要計算測試語料中每個分詞的在該篇文章中的詞頻TF,就可以計算TFIDF。

package org.apache.spark.mllib.feature
class HashingTF(val numFeatures: Int) extends Serializable {
  def this() = this(1 << 20)

  def nonNegativeMod(x: Int, mod: Int): Int = { //根據 numFeatures 設置的哈希表容量,來設定索引號
    val rawMod = x % mod
    rawMod + (if (rawMod < 0) mod else 0)
  }
  def indexOf(term: Any): Int = nonNegativeMod(term.##, numFeatures) //根據分詞來生成索引號

  def transform(document: Iterable[_]): Vector = {
    //每篇文章一個hash表,記錄每篇文章中的詞頻
    val termFrequencies = mutable.HashMap.empty[Int, Double]
    document.foreach { term =>
      val i = indexOf(term)
      //map中的getOrElse(i, 0.0)函數表示如果找到i位置的值就返回,否則就默認爲0.0
      termFrequencies.put(i, termFrequencies.getOrElse(i, 0.0) + 1.0)//注意這裏有加1計數操作
    }
    Vectors.sparse(numFeatures, termFrequencies.toSeq)
  }
  def transform[D <: Iterable[_]](dataset: RDD[D]): RDD[Vector] = {
    dataset.map(this.transform)
  }
}

class IDF(val minDocFreq: Int){
  def this() = this(0) //默認minDocFreq爲0,用來過濾文章出現次數過少的分詞

  def fit(dataset: RDD[Vector]): IDFModel = {
    val idf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator(minDocFreq = minDocFreq))(
      seqOp = (df, v) => df.add(v), 
      combOp = (df1, df2) => df1.merge(df2)
    ).idf()
    new IDFModel(idf)
  }
}

private object IDF {

  /** Document frequency aggregator. */
  class DocumentFrequencyAggregator(val minDocFreq: Int) extends Serializable {

    /** number of documents */
    private var m = 0L
    /** document frequency vector */
    private var df: BDV[Long] = _
    def this() = this(0)
    private def isEmpty: Boolean = m == 0L

    def add(doc: Vector): this.type = { //add -> 計算分詞在每個分區中的文章頻率
      if (isEmpty) {
        df = BDV.zeros(doc.size)
      }
      doc match {
        case SparseVector(size, indices, values) =>
          val nnz = indices.size
          var k = 0
          while (k < nnz) {
            if (values(k) > 0) {   //表示分詞values(k)在該篇文章中出現過
              df(indices(k)) += 1L  //計數分詞indices(k)出現在多少篇文章中
            }
            k += 1
          }
        case DenseVector(values) =>
          val n = values.size
          var j = 0
          while (j < n) {
            if (values(j) > 0.0) {  //作用和上面一樣,只是在spark中有DenseVector 和 SparseVector兩種向量的區別。
              df(j) += 1L
            }
            j += 1
          }
        case other =>
          throw new UnsupportedOperationException(
            s"Only sparse and dense vectors are supported but got ${other.getClass}.")
      }
      m += 1L
      this
    }

    /** Merges another. */
    def merge(other: DocumentFrequencyAggregator): this.type = { //將各個分區聚合到一起
      if (!other.isEmpty) {
        m += other.m
        if (df == null) {
          df = other.df.copy
        } else {
          df += other.df
        }
      }
      this
    }

    /** 返回當前IDF的向量 */
    def idf(): Vector = {
      if (isEmpty) {
        throw new IllegalStateException("Haven't seen any document yet.")
      }
      val n = df.length
      val inv = new Array[Double](n)
      var j = 0
      while (j < n) {
        if (df(j) >= minDocFreq) {
          inv(j) = math.log((m + 1.0) / (df(j) + 1.0)) //計算IDF —— log(D/d(j))
        }
        j += 1
      }
      Vectors.dense(inv)
    }
  }
}

class IDFModel(val idf: Vector) extends Serializable {   // idf 裏面存儲的是IDF向量
  def transform(dataset: RDD[Vector]): RDD[Vector] = {  //dataset裏面存儲的是TF向量
    val bcIdf = dataset.context.broadcast(idf)
    dataset.mapPartitions(iter => iter.map(v => IDFModel.transform(bcIdf.value, v))) 
  }
  def transform(v: Vector): Vector = IDFModel.transform(idf, v)
}

private object IDFModel {
  def transform(idf: Vector, v: Vector): Vector = { // 這裏就是  idf * v (v是TF向量)
    val n = v.size
    v match {
      case SparseVector(size, indices, values) =>
        val nnz = indices.size
        val newValues = new Array[Double](nnz)
        var k = 0
        while (k < nnz) {
          newValues(k) = values(k) * idf(indices(k))  //SparseVector 向量下 TF * IDF
          k += 1
        }
        Vectors.sparse(n, indices, newValues)
      case DenseVector(values) =>
        val newValues = new Array[Double](n)
        var j = 0
        while (j < n) {
          newValues(j) = values(j) * idf(j)  //DenseVector 向量下 TF * IDF
          j += 1
        }
        Vectors.dense(newValues)
      case other =>
        throw new UnsupportedOperationException(
          s"Only sparse and dense vectors are supported but got ${other.getClass}.")
    }
  }
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章