Spark MLlib Machine Learning, Part 4: KMeans

Data format (the second and third fields are the longitude and latitude used as clustering features):


*****,114.766907,35.218128,14,*****,***
****,114.969452,35.323708,30,0***,***
*****,114.879410,35.267296,80,***,***
*****,114.766907,35.218128,14,*****,***
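Only the second and third fields are fed to KMeans; the program below multiplies them by 1000 before clustering and divides the resulting centers by 1000 when writing them back out. A minimal parsing sketch for a single record, assuming the leading masked field is an id and the fourth is a count:

// Minimal parsing sketch; column names (id, longitude, latitude, count, ...) are assumptions
val line = "*****,114.766907,35.218128,14,*****,***"
val fields = line.split(',')
// keep longitude and latitude and scale them by 1000, as the full program below does
val features = Array(fields(1), fields(2)).map(_.trim.toDouble * 1000)
// features ≈ Array(114766.907, 35218.128)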



package com.agm.kmeans

import java.io.File
import java.io.PrintWriter
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors
import org.apache.log4j.{ Level, Logger }
object start {
  def main(args: Array[String]) {
    Logger.getLogger("org").setLevel(Level.ERROR)


    val conf = new SparkConf().setAppName("Simple Application") // name the application
    conf.setMaster("local")


    val sc = new SparkContext(conf)


    // Load the raw data, keep only longitude (field 1) and latitude (field 2),
    // and scale both by 1000 before building the feature vectors
    val data = sc.textFile("F:\\testData\\spark\\addressAllInfoUTF.txt")
    val data1 = data.map(f => f.split(',')).map(f => (f(1) + " " + f(2)))
    val parsedData = data1.map(s => Vectors.dense(s.split(' ').map(_.trim.toDouble * 1000))).cache()
    //parsedData.foreach(println)
    // Number of clusters
    val numClusters = 200


    // Maximum number of iterations
    val numIterations = 50


    // Number of runs to pick the best model (note: not actually passed to KMeans below)
    val runs = 10


    // Initialization mode: k-means|| (the parallel variant of k-means++)
    val initMode = "k-means||"


    val clusters = new KMeans()
      .setInitializationMode(initMode)
      .setK(numClusters)
      .setMaxIterations(numIterations)
      .run(parsedData)


    // Print which cluster each input point belongs to
    //println(parsedData.map(v => v.toString() + " belongs to cluster: " + clusters.predict(v)).collect().mkString("\n"))

    // Evaluate the clustering by computing the Within Set Sum of Squared Errors (WSSSE)


    val WSSSE = clusters.computeCost(parsedData)
    //val count = clusters.clusterCenters(0).size
    println()
    println("Within Set Sum of Squared Errors = " + WSSSE)
    // Count how many points fall into each cluster, keyed by cluster index
    // (a plain collect() would not guarantee that the counts line up with cluster indices)
    val res = parsedData.map(f => (clusters.predict(f), 1)).reduceByKey(_ + _).collectAsMap()


    // Predict the cluster for two sample points
    // (note: the training data was scaled by 1000, so these points are only illustrative)
    val a21 = clusters.predict(Vectors.dense(1.2, 1.3))
    val a22 = clusters.predict(Vectors.dense(4.1, 4.2))

    // Write out the cluster centers together with the per-cluster counts
    println("Cluster centers:")


    val writer = new PrintWriter(new File("F:\\testData\\spark\\learningScala.txt"))
    var i = 0
    for (center <- clusters.clusterCenters) {
      // undo the *1000 scaling and write "longitude latitude count"
      writer.println((center(0) / 1000) + " " + (center(1) / 1000) + " " + res.getOrElse(i, 0))
      i += 1
      //println(" " + center)
    }
    writer.close()
    println("Prediction of (1.2,1.3)-->" + a21)


    println("Prediction of (4.1,4.2)-->" + a22)


  }
}
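
numClusters = 200 is a fixed choice here. A common way to sanity-check k is to compute the WSSSE for several candidate values and look for an "elbow" where the cost stops dropping sharply. A rough sketch, reusing parsedData and the imports from the program above (the candidate k values are only illustrative, not taken from the original data):

    // Sketch: evaluate WSSSE for a few candidate k values to help choose numClusters
    val candidateKs = Seq(50, 100, 150, 200, 250)
    val costs = candidateKs.map { k =>
      val model = new KMeans()
        .setInitializationMode("k-means||")
        .setK(k)
        .setMaxIterations(50)
        .run(parsedData)
      (k, model.computeCost(parsedData))
    }
    costs.foreach { case (k, wssse) => println(s"k = $k, WSSSE = $wssse") }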