機器學習:KNN用java代碼實現

KNN算法:使用歐式距離計算方法,從源對象集合中選取距離目標節點最近的K個節點,判斷K個節點所屬類別最多的節點,即爲目標節點所屬的類別。

此處只是簡單的實現KNN算法的過程,其中有一些優化的地方不再修改,還請小夥伴自行優化。

KNN的model類:

package com.spring5.bigdata.knn;

/**
 * @author yinxf
 * @date 2020-05-16
 */
public class KnnNode {
    private float x; //X座標
    private float y; //Y座標
    private Float distance; //目標節點到此節點的距離
    private String type; //所屬類別

    public KnnNode(float x, float y, String type) {
        this.x = x;
        this.y = y;
        this.type = type;
    }

    public float getDistance() {
        return distance;
    }

    public void setDistance(float distance) {
        this.distance = distance;
    }

    public float getX() {
        return x;
    }

    public void setX(float x) {
        this.x = x;
    }

    public float getY() {
        return y;
    }

    public void setY(float y) {
        this.y = y;
    }

    public String getType() {
        return type;
    }

    public void setType(String type) {
        this.type = type;
    }

    @Override
    public String toString() {
        return "Node{" +
                "x=" + x +
                ", y=" + y +
                ", distance=" + distance +
                ", type='" + type + '\'' +
                '}';
    }

}

KNN的測試類,其中包括初始化數據,使用歐式距離計算選取距離目標節點最近的K個節點,

package com.spring5.bigdata.knn;

import java.util.ArrayList;
import java.util.List;

/**
 * @author yinxf
 * @date 2020-05-16
 */
public class Knn {

    //類別
    private final static String RED = "RED";
    private final static String BLACK = "BLACK";


    public static void main(String[] args) {
        //初始化所有節點座標
        List<KnnNode> totalType = init();
        //驗證此節點屬於哪個類別
//        KnnNode knnNode = new KnnNode(4,5,""); //所屬類別爲:black
        KnnNode knnNode = new KnnNode(3,2,""); //所屬類別爲:red

        //計算所有節點到目標節點的距離,歐式距離
        totalType = getDistance(totalType,knnNode);

        //計算距離目標節點最近的K個節點
        int k = 3;
        List<KnnNode> kList = getKList(totalType,k);

        //計算提供的節點屬於那種類別
        String resultType = getNodeType(kList);
        System.out.println("目標節點所屬類別爲:" + resultType);

    }

    /**
     * 查找距離目標節點最近的K個節點
     * @param totalType
     * @param k
     * @return
     */
    private static List<KnnNode> getKList(List<KnnNode> totalType, int k) {
        List<KnnNode> kList = new ArrayList<>(k);
        //選出距離目標節點最近的K個節點
        for (int i = 0 ; i < totalType.size() ; i++ ) {
            KnnNode type = totalType.get(i);
            if (i < k){
                kList.add(type);
            }else {
                boolean flag = false;
                //判斷當前節點小於K個節點集合中的節點
                for (KnnNode knnNode1 : kList) {
                    if (type.getDistance() < knnNode1.getDistance()){
                        flag = true;
                        break;
                    }
                }
                //替換距離目標節點最遠的節點
                if (flag) {
                    int index = 0 ;
                    for (int j = 0; j < k; j++) {
                        if (kList.get(j).getDistance() > type.getDistance()) {
                            index = j ;
                        }
                    }
                    kList.remove(index);
                    kList.add(type);

                    kList.forEach(list -> System.out.println(list.toString()));
                    System.out.println("=========================================");
                }
            }
        }
        return kList;
    }

    /**
     * 計算所有節點到目標節點的距離
     * @param totalType
     * @param knnNode
     * @return
     */
    private static List<KnnNode> getDistance(List<KnnNode> totalType, KnnNode knnNode){
        for (int i = 0 ; i < totalType.size() ; i++ ) {
            KnnNode type = totalType.get(i);
            float distance = distance(type, knnNode);
            type.setDistance(distance);
            System.out.println( i+"類別爲:【"+ type.getType() + "】  距離爲:【"+distance +"】" );
        }
        return totalType;
    }

    /**
     * 計算目標節點所屬類別
     * @param kList
     * @return
     */
    private static String getNodeType(List<KnnNode> kList) {
        //結算距離目標節點最近的K個節點中的,節點最多的類別是什麼
        int redNum = 0;
        int blackNum = 0;
        for (KnnNode result : kList) {
            if (RED.equals(result.getType())){
                redNum++;
            }else if (BLACK.equals(result.getType())){
                blackNum++;
            }
        }
        return blackNum > redNum ? BLACK : RED;
    }


    /**
     * 歐式距離計算公式
     * @param source
     * @param target
     * @return
     */
    private static float distance(KnnNode source, KnnNode target) {
        float x = source.getX() - target.getX();
        float y = source.getY() - target.getY();
        float z = x * x + y * y;
        float distance = (float) Math.sqrt(z);
        return distance;
    }

    /**
     * 初始化節點
     * @return
     */
    private static List<KnnNode> init() {
        List<KnnNode> totalType = new ArrayList<>();
        totalType.add(new KnnNode(1,2,RED));
        totalType.add(new KnnNode(2,2,RED));
        totalType.add(new KnnNode(1,3,RED));
        totalType.add(new KnnNode(2,1,RED));
        totalType.add(new KnnNode(2,3,RED));
        totalType.add(new KnnNode(3,5,BLACK));
        totalType.add(new KnnNode(4,6,BLACK));
        totalType.add(new KnnNode(3,4,BLACK));
        totalType.add(new KnnNode(5,4,BLACK));
        totalType.add(new KnnNode(5,3,BLACK));
        return totalType;
    }
}

測試結果如下:

0類別爲:【RED】  距離爲:【2.0】
1類別爲:【RED】  距離爲:【1.0】
2類別爲:【RED】  距離爲:【2.236068】
3類別爲:【RED】  距離爲:【1.4142135】
4類別爲:【RED】  距離爲:【1.4142135】
5類別爲:【BLACK】  距離爲:【3.0】
6類別爲:【BLACK】  距離爲:【4.1231055】
7類別爲:【BLACK】  距離爲:【2.0】
8類別爲:【BLACK】  距離爲:【2.828427】
9類別爲:【BLACK】  距離爲:【2.236068】
Node{x=1.0, y=2.0, distance=2.0, type='RED'}
Node{x=2.0, y=2.0, distance=1.0, type='RED'}
Node{x=2.0, y=1.0, distance=1.4142135, type='RED'}
=========================================
Node{x=2.0, y=2.0, distance=1.0, type='RED'}
Node{x=2.0, y=1.0, distance=1.4142135, type='RED'}
Node{x=2.0, y=3.0, distance=1.4142135, type='RED'}
=========================================
目標節點所屬類別爲:RED

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章