Spark ML基本算法【Correlation相关性】

一.简介

计算两个系列数据之间的相关性是“统计”中的常见操作。spark.ml 提供了很多系列中的灵活性,计算两两相关性。目前支持的相关方法是Pearson和Spearman的相关。
Correlation 使用指定的方法为向量的输入数据集计算相关矩阵。输出将是一个DataFrame,其中包含向量列的相关矩阵。

二.代码实战

    import spark.implicits._

    val data = Seq(
      Vectors.sparse(4, Seq((0, 7.0), (3, -2.0))), // 稀疏向量,等价于dense(7.0, 0.0, 0.0, -2.0)
      Vectors.dense(0.0, 5.0, 3.0, 3.0), // 稠密向量
      Vectors.dense(0.0, 7.0, 5.0, 8.0),
      Vectors.sparse(4, Seq((0, 9.0), (3, 1.0)))
    )

    /**
      * NaN : not a number
      */
    val df = data.map(Tuple1.apply).toDF("features")
    df.show(false)
    val Row(coeff : Matrix) = Correlation.corr(df, "features").head // 使用默认的Pearson
    println(coeff)
    
    /**
      * 对于Spearman,是排名相关性,我们需要为每个列创建一个RDD [Double]并对其进行排序,以便检索排名,
      * 然后将这些列重新连接到RDD [Vector]中,这是相当昂贵的。 在使用`method =“ spearman”`调用corr之前,
      * 请缓存输入数据集,以避免重新计算公共谱系。
      */
    val Row(coeff2 : Matrix) = Correlation.corr(df, "features", "spearman").head // 指定相关类型
    println(coeff2)

    spark.stop()
  }
}

完整代码及相关Spark程序查看Github:Spark代码实战

三.执行结果

测试数据:
在这里插入图片描述
生成的相关性矩阵:
1.使用默认pearson
在这里插入图片描述
2.使用spearman
在这里插入图片描述

  • 根据生成的相关性矩阵可知,两者的差距不大,一般情况下使用默认的相关性方法参加即可。
  • 第二和第三个向量相关性最高,其次是第三个向量,第一个向量相关性较低,原因应该是该向量存在负数。在进行相关性之前可以考虑使用归一化以降低此类影响。

四.源码分析

@Since("2.2.0")
@Experimental
object Correlation {
  @Since("2.2.0")
  def corr(dataset: Dataset[_], column: String, method: String): DataFrame = {
    val rdd = dataset.select(column).rdd.map {
      case Row(v: Vector) => OldVectors.fromML(v)
    }
    val oldM = OldStatistics.corr(rdd, method)
    val name = s"$method($column)"
    val schema = StructType(Array(StructField(name, SQLDataTypes.MatrixType, nullable = false)))
    dataset.sparkSession.createDataFrame(Seq(Row(oldM.asML)).asJava, schema)
  }

  /**
   * Compute the Pearson correlation matrix for the input Dataset of Vectors.
   */
  @Since("2.2.0")
  def corr(dataset: Dataset[_], column: String): DataFrame = {
    corr(dataset, column, "pearson")
  }
}

从源代码上可以看出默认是使用pearson作为相关性计算的,而且具体计算的逻辑是在OldStatistics.corr(rdd, method)中实现的。
根据参数的详细注释可知:

  1. 第一个参数:一个dataset或dataframe。
  2. 第二个参数:列的名称,需要为其计算相关系数。 必须是数据集的一列,并且必须包含Vector对象。
  3. 第三个参数:指定用于计算相关性的方法。包括pearson (default), spearman
  4. 返回值:包含向量列的相关矩阵的dataframe。 该dataframe包含单行和单列名称。
  5. 违法参数异常IllegalArgumentException:如果该列不是数据集中的有效列,或者该列的内容不是Vector类型。
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章