一.简介
计算两个系列数据之间的相关性是“统计”中的常见操作。spark.ml 提供了很多系列中的灵活性,计算两两相关性。目前支持的相关方法是Pearson和Spearman的相关。
Correlation 使用指定的方法为向量的输入数据集计算相关矩阵。输出将是一个DataFrame,其中包含向量列的相关矩阵。
二.代码实战
import spark.implicits._
val data = Seq(
Vectors.sparse(4, Seq((0, 7.0), (3, -2.0))), // 稀疏向量,等价于dense(7.0, 0.0, 0.0, -2.0)
Vectors.dense(0.0, 5.0, 3.0, 3.0), // 稠密向量
Vectors.dense(0.0, 7.0, 5.0, 8.0),
Vectors.sparse(4, Seq((0, 9.0), (3, 1.0)))
)
/**
* NaN : not a number
*/
val df = data.map(Tuple1.apply).toDF("features")
df.show(false)
val Row(coeff : Matrix) = Correlation.corr(df, "features").head // 使用默认的Pearson
println(coeff)
/**
* 对于Spearman,是排名相关性,我们需要为每个列创建一个RDD [Double]并对其进行排序,以便检索排名,
* 然后将这些列重新连接到RDD [Vector]中,这是相当昂贵的。 在使用`method =“ spearman”`调用corr之前,
* 请缓存输入数据集,以避免重新计算公共谱系。
*/
val Row(coeff2 : Matrix) = Correlation.corr(df, "features", "spearman").head // 指定相关类型
println(coeff2)
spark.stop()
}
}
完整代码及相关Spark程序查看Github:Spark代码实战
三.执行结果
测试数据:
生成的相关性矩阵:
1.使用默认pearson
2.使用spearman
- 根据生成的相关性矩阵可知,两者的差距不大,一般情况下使用默认的相关性方法参加即可。
- 第二和第三个向量相关性最高,其次是第三个向量,第一个向量相关性较低,原因应该是该向量存在负数。在进行相关性之前可以考虑使用归一化以降低此类影响。
四.源码分析
@Since("2.2.0")
@Experimental
object Correlation {
@Since("2.2.0")
def corr(dataset: Dataset[_], column: String, method: String): DataFrame = {
val rdd = dataset.select(column).rdd.map {
case Row(v: Vector) => OldVectors.fromML(v)
}
val oldM = OldStatistics.corr(rdd, method)
val name = s"$method($column)"
val schema = StructType(Array(StructField(name, SQLDataTypes.MatrixType, nullable = false)))
dataset.sparkSession.createDataFrame(Seq(Row(oldM.asML)).asJava, schema)
}
/**
* Compute the Pearson correlation matrix for the input Dataset of Vectors.
*/
@Since("2.2.0")
def corr(dataset: Dataset[_], column: String): DataFrame = {
corr(dataset, column, "pearson")
}
}
从源代码上可以看出默认是使用pearson作为相关性计算的,而且具体计算的逻辑是在OldStatistics.corr(rdd, method)中实现的。
根据参数的详细注释可知:
- 第一个参数:一个dataset或dataframe。
- 第二个参数:列的名称,需要为其计算相关系数。 必须是数据集的一列,并且必须包含Vector对象。
- 第三个参数:指定用于计算相关性的方法。包括
pearson
(default),spearman
。 - 返回值:包含向量列的相关矩阵的dataframe。 该dataframe包含单行和单列名称。
- 违法参数异常IllegalArgumentException:如果该列不是数据集中的有效列,或者该列的内容不是Vector类型。