機器學習-KNN算法原理 && Spark實現

不懂算法的數據開發者不是一個好的算法工程師,還記得研究生時候,導師講過的一些數據挖掘算法,頗有興趣,但是無奈工作後接觸少了,數據工程師的鄙視鏈,模型>實時>離線數倉>ETL工程師>BI工程師(不喜勿噴哈),現在做的工作主要是離線數倉,當然前期也做過一些ETL的工作,爲了職業的長遠發展,拓寬自己的技術邊界,有必要逐步深入實時和模型,所以從本篇文章開始,也是列個FLAG,深入學習實時和模型部分。

改變自己,從提升自己不擅長領域的事情開始。

1. KNN - K近鄰算法簡介

首先,KNN是一種分類算法,有監督的機器學習,將訓練集的類別打標籤,當測試對象和訓練對象完全匹配時候,就可以對其進行分類,但是測試對象與訓練對象的多個類,如何匹配呢,前面可以判別是否測試對象術語某個訓練對象,但是如果是多個訓練對象類,那如何解決這種問題呢,所以就有了KNN,KNN是通過測量不同特徵值之間的距離進行分類。它的思路是:如果一個樣本在特徵空間中的k個最相似(即特徵空間中最鄰近)的樣本中的大多數屬於某一個類別,則該樣本也屬於這個類別,其中K通常是不大於20的整數。KNN算法中,所選擇的鄰居都是已經正確分類的對象。該方法在定類決策上只依據最鄰近的一個或者幾個樣本的類別來決定待分樣本所屬的類別

file
KNN算法的核心思想是,如果一個樣本在特徵空間中的K個最相鄰的樣本中的大多數屬於某一個類別,則該樣本也屬於這個類別,並具有這個類別上樣本的特性。該方法在確定分類決策上只依據最鄰近的一個或者幾個樣本的類別來決定待分樣本所屬的類別。KNN方法在類別決策時,只與極少量的相鄰樣本有關。由於KNN方法主要靠周圍有限的鄰近的樣本,而不是靠判別類域的方法來確定所屬類別的,因此對於類域的交叉或重疊較多的待分樣本集來說,KNN方法較其他方法更爲適合。

2.KNN 算法流程

2.1 準備數據,對數據進行預處理 。

2.2 計算測試樣本點(也就是待分類點)到其他每個樣本點的距離。

2.3 對每個距離進行排序,然後選擇出距離最小的K個點 。

2.4 對K個點所屬的類別進行比較,根據少數服從多數的原則,將測試樣本點歸入在K個點中佔比最高的那一

3. KNN算法優缺點

優點:易於理解,實現起了很方便,無需預估參數,無需訓練

缺點:數據集中如果某個類的數據量很大,那麼勢必導致測試集合跑到這個類的更多,因爲離這些點較近的概率也較大

4.KNN算法Spark實現

4.1 數據下載和說明

鏈接:https://pan.baidu.com/s/1FmFxSrPIynO3udernLU0yQ提取碼:hell
複製這段內容後打開百度網盤手機App,操作更方便哦

鳶尾花數據集,數據集包含3類共150調數據,每類含50個數據,每條記錄含4個特徵:花萼長度、花萼寬度、花瓣長度、花瓣寬度

過這4個 特徵預測鳶尾花卉屬於(iris-setosa, iris-versicolour, iris-virginica)中的哪一品種

4.2 實現

package com.hoult.work

import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}

object KNNDemo {
  def main(args: Array[String]): Unit = {

    //1.初始化
    val conf=new SparkConf().setAppName("SimpleKnn").setMaster("local[*]")
    val sc=new SparkContext(conf)
    val K=15

    //2.讀取數據,封裝數據
    val data: RDD[LabelPoint] = sc.textFile("data/lris.csv")
      .map(line => {
        val arr = line.split(",")
        if (arr.length == 6) {
          LabelPoint(arr.last, arr.init.map(_.toDouble))
        } else {
          println(arr.toBuffer)
          LabelPoint(" ", arr.map(_.toDouble))
        }
      })


    //3.過濾出樣本數據和測試數據
    val sampleData=data.filter(_.label!=" ")
    val testData=data.filter(_.label==" ").map(_.point).collect()

    //4.求每一條測試數據與樣本數據的距離
    testData.foreach(elem=>{
      val distance=sampleData.map(x=>(getDistance(elem,x.point),x.label))
      //獲取距離最近的k個樣本
      val minDistance=distance.sortBy(_._1).take(K)
      //取出這k個樣本的label並且獲取出現最多的label即爲測試數據的label
      val labels=minDistance.map(_._2)
        .groupBy(x=>x)
        .mapValues(_.length)
        .toList
        .sortBy(_._2).reverse
        .take(1)
        .map(_._1)
      printf(s"${elem.toBuffer.mkString(",")},${labels.toBuffer.mkString(",")}")
      println()
    })
    sc.stop()

  }

  case class LabelPoint(label:String,point:Array[Double])

  import scala.math._

  def getDistance(x:Array[Double],y:Array[Double]):Double={
    sqrt(x.zip(y).map(z=>pow(z._1-z._2,2)).sum)
  }
}

完整代碼:https://github.com/hulichao/bigdata-spark/blob/master/src/main/scala/com/hoult/work/KNNDemo.scala
吳邪,小三爺,混跡於後臺,大數據,人工智能領域的小菜鳥。
更多請關注
file

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章