spark機器學習 K-means聚類算法

推薦大家去看原文博主的文章,條理清晰閱讀方便,轉載是爲了方便以後個人查閱

https://blog.csdn.net/weixin_43283487/article/details/89033599

1.聚類和分類區別

K-means聚類算法中K表示將數據聚類成K個簇,means表示每個聚類中數據的均值作爲該簇的中心,也稱爲質心。K-means聚類試圖將相似的對象歸爲同一個簇,將不相似的對象歸爲不同簇,這裏需要一種對數據衡量相似度的計算方法,K-means算法是典型的基於距離的聚類算法,採用距離作爲相似度的評價指標,默認以歐式距離作爲相似度測度,即兩個對象的距離越近,其相似度就越大。

聚類和分類最大的不同在於,分類的目標是事先已知的,而聚類則不一樣,聚類事先不知道目標變量是什麼,類別沒有像分類那樣被預先定義出來,也就是聚類分組不需要提前被告知所劃分的組應該是什麼樣的,因爲我們甚至可能都不知道我們再尋找什麼,所以聚類是用於知識發現而不是預測,所以,聚類有時也叫無監督學習。

2. K-means原理

假設有一批關於計算機科學和數學統計相關的人才,這批人才中計算機人才、機器學習人才、數學人才三類,那麼該如何將這批數據進行聚類?

我們可以直觀的感覺到應該如下分類:

但問題是計算機不會直觀的去觀察數據,首先將這批數據向量化,K-means聚類會隨機在這些點中找到三個點,然後計算所有的樣本到當前三個點的距離大小,判斷樣本點與當前三個點哪個距離比較近,當前樣本就屬於那個類。

當經過一次計算之後,就是經過了一次迭代過程,當完成一次迭代後,就分出來是三個類別(簇),每個類別中都有質心。然後進行下一次迭代,繼續計算所有點到三個類別中心點的距離,按照每個樣本點與哪個質心距離最近就屬於哪個簇,以此類推,繼續迭代。當本次計算的中心點的距離較上次計算的中心點的位置不再變化,那麼停止迭代。


K值的選擇一般可以根據問題的內容來確定,也可以根據肘部法來確定。如圖:橫軸表示K值的選擇,縱軸表示對應的K值下所有聚類的平均畸變程度。

每個類的畸變程度是每個類別下每個樣本到質心的位置距離的平方和。類內部成員越是緊湊,那麼類的畸變程度越低,這個類內部相似性越大,聚類也就越好。如圖,在k=1請況下,相比k=2情況下,類的平均畸變程度變化大,說明,k=2的情況類的緊湊程度比k=1情況下要緊湊的多。同理,發現當k=3之後,隨着k的增大,類的平均畸變程度變化不大,說明k=3是比較好的k值。k>3後類的平均畸變程度變化不大,聚類的個數越多,有可能類與類之間的相似度越大,類的內部反而沒有相似度,這種聚類也是不好的。舉個極端的例子,有1000個數據,分成1000個類,那麼類的平均畸變程度是0,那麼每個數據都是一類,類與類之間的相似度大,類內部沒有相似性。
K-means算法的思想就是對空間K個點爲中心進行聚類,對靠近他們的對象進行歸類,通過迭代的方法,逐次更新聚類中心(質心)的值,直到得到最好的聚類結果。K-means過程:

首先選擇k個類別的中心點
對任意一個樣本,求其到各類中心的距離,將該樣本歸到距離最短的中心所在的類
聚好類後,重新計算每個聚類的中心點位置
重複2,3步驟迭代,直到k個類中心點的位置不變,或者達到一定的迭代次數,則迭代結束,否則繼續迭代

3. K-means++算法

K-means算法假設聚類爲3類,開始選取每個類的中心點的時候是隨機選取,有可能三個點選取的位置非常近,導致後面每次聚類重新求各類中心的迭代次數增加。K-means++在選取第一個聚類中心點的時候也是隨機選取,當選取第二個中心點的時候,距離當前已經選擇的聚類中心點的距離越遠的點會有更高的概率被選中,假設已經選取n個點,當選取第n+1個聚類中心時,距離當前n個聚類中心點越遠的點越會被選中,這種思想是聚類中心的點離的越遠越好,這樣就大大降低的找到最終聚類各個中心點的迭代次數,提高了效率。

4.spark訓練模型

val model = new KMeans().
//設置聚類的類數
setK(numClusters).
//設置找中心點最大的迭代次數
setMaxIterations(numIterations).
run(parsedData)



object KMeans {

  def main(args: Array[String]) {
    //1 構建Spark對象
    val conf = new SparkConf().setAppName("KMeans").setMaster("local")
    val sc = new SparkContext(conf)

    // 讀取樣本數據1,格式爲LIBSVM format
    val data = sc.textFile("kmeans_data.txt")
val parsedData = data.map(s => Vectors.dense(s.split(' ')
.map(_.toDouble))).cache()
    val numClusters = 4
    val numIterations = 100
    val model = new KMeans().
      //設置聚類的類數
      setK(numClusters).
      //設置找中心點最大的迭代次數
      setMaxIterations(numIterations).
      run(parsedData)
      
    //四個中心點的座標
    val centers = model.clusterCenters
    val k = model.k
    centers.foreach(println)
    println(k)
    //保存模型
//    model.save(sc, "./Kmeans_model")
    //加載模型
    val sameModel = KMeansModel.load(sc, "./Kmeans_model")
    println(sameModel.predict(Vectors.dense(1,1,1)))
    val sqlContext = new SQLContext(sc)
    sqlContext.read.parquet("./Kmeans_model/data").show()
    
  }
}

給Kmeans指定中心點座標:
 

object KMeans2 {

  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("KMeans2").setMaster("local")
    val sc = new SparkContext(conf)

    val rdd = sc.parallelize(List(
      Vectors.dense(Array(-0.1, 0.0, 0.0)),
      Vectors.dense(Array(9.0, 9.0, 9.0)),
      Vectors.dense(Array(3.0, 2.0, 1.0))))
      
    //指定文件 kmeans_data.txt 中的六個點爲中心點座標。
    val centroids: Array[Vector] = sc.textFile("kmeans_data.txt")
        .map(_.split(" ").map(_.toDouble))
        .map(Vectors.dense(_))
        .collect()
    val model = new KMeansModel(clusterCenters=centroids)
    println("聚類個數 = "+model.k)
    //模型中心點
    model.clusterCenters.foreach { println }
    //預測指定的三條數據
    val result = model.predict(rdd)
    result.collect().foreach(println(_))
  }
}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章