spark机器学习 K-means聚类算法

推荐大家去看原文博主的文章,条理清晰阅读方便,转载是为了方便以后个人查阅

https://blog.csdn.net/weixin_43283487/article/details/89033599

1.聚类和分类区别

K-means聚类算法中K表示将数据聚类成K个簇,means表示每个聚类中数据的均值作为该簇的中心,也称为质心。K-means聚类试图将相似的对象归为同一个簇,将不相似的对象归为不同簇,这里需要一种对数据衡量相似度的计算方法,K-means算法是典型的基于距离的聚类算法,采用距离作为相似度的评价指标,默认以欧式距离作为相似度测度,即两个对象的距离越近,其相似度就越大。

聚类和分类最大的不同在于,分类的目标是事先已知的,而聚类则不一样,聚类事先不知道目标变量是什么,类别没有像分类那样被预先定义出来,也就是聚类分组不需要提前被告知所划分的组应该是什么样的,因为我们甚至可能都不知道我们再寻找什么,所以聚类是用于知识发现而不是预测,所以,聚类有时也叫无监督学习。

2. K-means原理

假设有一批关于计算机科学和数学统计相关的人才,这批人才中计算机人才、机器学习人才、数学人才三类,那么该如何将这批数据进行聚类?

我们可以直观的感觉到应该如下分类:

但问题是计算机不会直观的去观察数据,首先将这批数据向量化,K-means聚类会随机在这些点中找到三个点,然后计算所有的样本到当前三个点的距离大小,判断样本点与当前三个点哪个距离比较近,当前样本就属于那个类。

当经过一次计算之后,就是经过了一次迭代过程,当完成一次迭代后,就分出来是三个类别(簇),每个类别中都有质心。然后进行下一次迭代,继续计算所有点到三个类别中心点的距离,按照每个样本点与哪个质心距离最近就属于哪个簇,以此类推,继续迭代。当本次计算的中心点的距离较上次计算的中心点的位置不再变化,那么停止迭代。


K值的选择一般可以根据问题的内容来确定,也可以根据肘部法来确定。如图:横轴表示K值的选择,纵轴表示对应的K值下所有聚类的平均畸变程度。

每个类的畸变程度是每个类别下每个样本到质心的位置距离的平方和。类内部成员越是紧凑,那么类的畸变程度越低,这个类内部相似性越大,聚类也就越好。如图,在k=1请况下,相比k=2情况下,类的平均畸变程度变化大,说明,k=2的情况类的紧凑程度比k=1情况下要紧凑的多。同理,发现当k=3之后,随着k的增大,类的平均畸变程度变化不大,说明k=3是比较好的k值。k>3后类的平均畸变程度变化不大,聚类的个数越多,有可能类与类之间的相似度越大,类的内部反而没有相似度,这种聚类也是不好的。举个极端的例子,有1000个数据,分成1000个类,那么类的平均畸变程度是0,那么每个数据都是一类,类与类之间的相似度大,类内部没有相似性。
K-means算法的思想就是对空间K个点为中心进行聚类,对靠近他们的对象进行归类,通过迭代的方法,逐次更新聚类中心(质心)的值,直到得到最好的聚类结果。K-means过程:

首先选择k个类别的中心点
对任意一个样本,求其到各类中心的距离,将该样本归到距离最短的中心所在的类
聚好类后,重新计算每个聚类的中心点位置
重复2,3步骤迭代,直到k个类中心点的位置不变,或者达到一定的迭代次数,则迭代结束,否则继续迭代

3. K-means++算法

K-means算法假设聚类为3类,开始选取每个类的中心点的时候是随机选取,有可能三个点选取的位置非常近,导致后面每次聚类重新求各类中心的迭代次数增加。K-means++在选取第一个聚类中心点的时候也是随机选取,当选取第二个中心点的时候,距离当前已经选择的聚类中心点的距离越远的点会有更高的概率被选中,假设已经选取n个点,当选取第n+1个聚类中心时,距离当前n个聚类中心点越远的点越会被选中,这种思想是聚类中心的点离的越远越好,这样就大大降低的找到最终聚类各个中心点的迭代次数,提高了效率。

4.spark训练模型

val model = new KMeans().
//设置聚类的类数
setK(numClusters).
//设置找中心点最大的迭代次数
setMaxIterations(numIterations).
run(parsedData)



object KMeans {

  def main(args: Array[String]) {
    //1 构建Spark对象
    val conf = new SparkConf().setAppName("KMeans").setMaster("local")
    val sc = new SparkContext(conf)

    // 读取样本数据1,格式为LIBSVM format
    val data = sc.textFile("kmeans_data.txt")
val parsedData = data.map(s => Vectors.dense(s.split(' ')
.map(_.toDouble))).cache()
    val numClusters = 4
    val numIterations = 100
    val model = new KMeans().
      //设置聚类的类数
      setK(numClusters).
      //设置找中心点最大的迭代次数
      setMaxIterations(numIterations).
      run(parsedData)
      
    //四个中心点的座标
    val centers = model.clusterCenters
    val k = model.k
    centers.foreach(println)
    println(k)
    //保存模型
//    model.save(sc, "./Kmeans_model")
    //加载模型
    val sameModel = KMeansModel.load(sc, "./Kmeans_model")
    println(sameModel.predict(Vectors.dense(1,1,1)))
    val sqlContext = new SQLContext(sc)
    sqlContext.read.parquet("./Kmeans_model/data").show()
    
  }
}

给Kmeans指定中心点座标:
 

object KMeans2 {

  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("KMeans2").setMaster("local")
    val sc = new SparkContext(conf)

    val rdd = sc.parallelize(List(
      Vectors.dense(Array(-0.1, 0.0, 0.0)),
      Vectors.dense(Array(9.0, 9.0, 9.0)),
      Vectors.dense(Array(3.0, 2.0, 1.0))))
      
    //指定文件 kmeans_data.txt 中的六个点为中心点座标。
    val centroids: Array[Vector] = sc.textFile("kmeans_data.txt")
        .map(_.split(" ").map(_.toDouble))
        .map(Vectors.dense(_))
        .collect()
    val model = new KMeansModel(clusterCenters=centroids)
    println("聚类个数 = "+model.k)
    //模型中心点
    model.clusterCenters.foreach { println }
    //预测指定的三条数据
    val result = model.predict(rdd)
    result.collect().foreach(println(_))
  }
}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章