KNN算法:使用歐式距離計算方法,從源對象集合中選取距離目標節點最近的K個節點,判斷K個節點所屬類別最多的節點,即爲目標節點所屬的類別。
此處只是簡單的實現KNN算法的過程,其中有一些優化的地方不再修改,還請小夥伴自行優化。
KNN的model類:
package com.spring5.bigdata.knn;
/**
* @author yinxf
* @date 2020-05-16
*/
public class KnnNode {
private float x; //X座標
private float y; //Y座標
private Float distance; //目標節點到此節點的距離
private String type; //所屬類別
public KnnNode(float x, float y, String type) {
this.x = x;
this.y = y;
this.type = type;
}
public float getDistance() {
return distance;
}
public void setDistance(float distance) {
this.distance = distance;
}
public float getX() {
return x;
}
public void setX(float x) {
this.x = x;
}
public float getY() {
return y;
}
public void setY(float y) {
this.y = y;
}
public String getType() {
return type;
}
public void setType(String type) {
this.type = type;
}
@Override
public String toString() {
return "Node{" +
"x=" + x +
", y=" + y +
", distance=" + distance +
", type='" + type + '\'' +
'}';
}
}
KNN的測試類,其中包括初始化數據,使用歐式距離計算選取距離目標節點最近的K個節點,
package com.spring5.bigdata.knn;
import java.util.ArrayList;
import java.util.List;
/**
* @author yinxf
* @date 2020-05-16
*/
public class Knn {
//類別
private final static String RED = "RED";
private final static String BLACK = "BLACK";
public static void main(String[] args) {
//初始化所有節點座標
List<KnnNode> totalType = init();
//驗證此節點屬於哪個類別
// KnnNode knnNode = new KnnNode(4,5,""); //所屬類別爲:black
KnnNode knnNode = new KnnNode(3,2,""); //所屬類別爲:red
//計算所有節點到目標節點的距離,歐式距離
totalType = getDistance(totalType,knnNode);
//計算距離目標節點最近的K個節點
int k = 3;
List<KnnNode> kList = getKList(totalType,k);
//計算提供的節點屬於那種類別
String resultType = getNodeType(kList);
System.out.println("目標節點所屬類別爲:" + resultType);
}
/**
* 查找距離目標節點最近的K個節點
* @param totalType
* @param k
* @return
*/
private static List<KnnNode> getKList(List<KnnNode> totalType, int k) {
List<KnnNode> kList = new ArrayList<>(k);
//選出距離目標節點最近的K個節點
for (int i = 0 ; i < totalType.size() ; i++ ) {
KnnNode type = totalType.get(i);
if (i < k){
kList.add(type);
}else {
boolean flag = false;
//判斷當前節點小於K個節點集合中的節點
for (KnnNode knnNode1 : kList) {
if (type.getDistance() < knnNode1.getDistance()){
flag = true;
break;
}
}
//替換距離目標節點最遠的節點
if (flag) {
int index = 0 ;
for (int j = 0; j < k; j++) {
if (kList.get(j).getDistance() > type.getDistance()) {
index = j ;
}
}
kList.remove(index);
kList.add(type);
kList.forEach(list -> System.out.println(list.toString()));
System.out.println("=========================================");
}
}
}
return kList;
}
/**
* 計算所有節點到目標節點的距離
* @param totalType
* @param knnNode
* @return
*/
private static List<KnnNode> getDistance(List<KnnNode> totalType, KnnNode knnNode){
for (int i = 0 ; i < totalType.size() ; i++ ) {
KnnNode type = totalType.get(i);
float distance = distance(type, knnNode);
type.setDistance(distance);
System.out.println( i+"類別爲:【"+ type.getType() + "】 距離爲:【"+distance +"】" );
}
return totalType;
}
/**
* 計算目標節點所屬類別
* @param kList
* @return
*/
private static String getNodeType(List<KnnNode> kList) {
//結算距離目標節點最近的K個節點中的,節點最多的類別是什麼
int redNum = 0;
int blackNum = 0;
for (KnnNode result : kList) {
if (RED.equals(result.getType())){
redNum++;
}else if (BLACK.equals(result.getType())){
blackNum++;
}
}
return blackNum > redNum ? BLACK : RED;
}
/**
* 歐式距離計算公式
* @param source
* @param target
* @return
*/
private static float distance(KnnNode source, KnnNode target) {
float x = source.getX() - target.getX();
float y = source.getY() - target.getY();
float z = x * x + y * y;
float distance = (float) Math.sqrt(z);
return distance;
}
/**
* 初始化節點
* @return
*/
private static List<KnnNode> init() {
List<KnnNode> totalType = new ArrayList<>();
totalType.add(new KnnNode(1,2,RED));
totalType.add(new KnnNode(2,2,RED));
totalType.add(new KnnNode(1,3,RED));
totalType.add(new KnnNode(2,1,RED));
totalType.add(new KnnNode(2,3,RED));
totalType.add(new KnnNode(3,5,BLACK));
totalType.add(new KnnNode(4,6,BLACK));
totalType.add(new KnnNode(3,4,BLACK));
totalType.add(new KnnNode(5,4,BLACK));
totalType.add(new KnnNode(5,3,BLACK));
return totalType;
}
}
測試結果如下:
0類別爲:【RED】 距離爲:【2.0】
1類別爲:【RED】 距離爲:【1.0】
2類別爲:【RED】 距離爲:【2.236068】
3類別爲:【RED】 距離爲:【1.4142135】
4類別爲:【RED】 距離爲:【1.4142135】
5類別爲:【BLACK】 距離爲:【3.0】
6類別爲:【BLACK】 距離爲:【4.1231055】
7類別爲:【BLACK】 距離爲:【2.0】
8類別爲:【BLACK】 距離爲:【2.828427】
9類別爲:【BLACK】 距離爲:【2.236068】
Node{x=1.0, y=2.0, distance=2.0, type='RED'}
Node{x=2.0, y=2.0, distance=1.0, type='RED'}
Node{x=2.0, y=1.0, distance=1.4142135, type='RED'}
=========================================
Node{x=2.0, y=2.0, distance=1.0, type='RED'}
Node{x=2.0, y=1.0, distance=1.4142135, type='RED'}
Node{x=2.0, y=3.0, distance=1.4142135, type='RED'}
=========================================
目標節點所屬類別爲:RED