[Machine Learning] The silhouette coefficient for clustering algorithms, a Java implementation

This post implements the silhouette coefficient (see the Wikipedia or Baidu Baike entries). Its purpose is to evaluate how good a clustering result is.

I recommend reading the Wikipedia article; some of the wording on Baidu Baike is hard to follow, for example this sentence: "compute b(i) = min(average distance from vector i to the points of every cluster it does not belong to)".

Below is Wikipedia's description of the silhouette coefficient. Roughly, my understanding is:

a(i) is the average distance from the center point to the points in its own cluster.

b(i) is the smallest of the average distances from the center point to each of the other clusters, and the denominator of the formula below takes the larger of the two values. (Strictly speaking, Wikipedia defines a(i) and b(i) per point; the implementation in this post uses each cluster's center as the reference point, which is a simplification.)

Assume the data have been clustered via any technique, such as k-means, into k clusters. For each datum i, let a(i) be the average dissimilarity of i with all other data within the same cluster. We can interpret a(i) as how well i is assigned to its cluster (the smaller the value, the better the assignment). We then define the average dissimilarity of point i to a cluster c as the average of the distance from i to all points in c.

Let b(i) be the lowest average dissimilarity of i to any other cluster, of which i is not a member. The cluster with this lowest average dissimilarity is said to be the "neighbouring cluster" of i because it is the next best fit cluster for point i. We now define a silhouette:

    s(i) = (b(i) - a(i)) / max{a(i), b(i)}

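To make the definition concrete, here is a minimal, self-contained sketch of my own (plain Java, no Spark; the class and helper names are just for illustration) that computes the textbook per-point silhouette s(i) on the same six sample points used in the Spark code below:

// Minimal sketch of the per-point silhouette from the Wikipedia definition above.
public class SilhouetteSketch {

    // Euclidean distance between two points of equal dimension
    static double dist(double[] x, double[] y) {
        double s = 0.0;
        for (int d = 0; d < x.length; d++) {
            s += Math.pow(x[d] - y[d], 2);
        }
        return Math.sqrt(s);
    }

    // s(i) = (b(i) - a(i)) / max(a(i), b(i))
    static double silhouette(double[][] points, int[] label, int i, int k) {
        double a = 0.0;              // average distance to the other points of i's own cluster
        double b = Double.MAX_VALUE; // smallest average distance to any foreign cluster
        for (int c = 0; c < k; c++) {
            double sum = 0.0;
            int count = 0;
            for (int j = 0; j < points.length; j++) {
                if (j != i && label[j] == c) {
                    sum += dist(points[i], points[j]);
                    count++;
                }
            }
            if (count == 0) {
                continue; // skip empty clusters
            }
            if (c == label[i]) {
                a = sum / count;
            } else {
                b = Math.min(b, sum / count);
            }
        }
        return (b - a) / Math.max(a, b);
    }

    public static void main(String[] args) {
        double[][] pts = {{0.0, 0.0, 0.0}, {0.1, 0.1, 0.1}, {0.2, 0.2, 0.2},
                          {9.0, 9.0, 9.0}, {9.1, 9.1, 9.1}, {9.2, 9.2, 9.2}};
        int[] label = {0, 0, 0, 1, 1, 1};
        // s(i) for the point (0.1, 0.1, 0.1)
        System.out.println(silhouette(pts, label, 1, 2));
    }
}

For this well-separated toy data, s(i) for the point (0.1, 0.1, 0.1) comes out close to 1 (roughly 0.99), which is what we expect from a good clustering.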
The code is as follows:

package com.mj.datamining.test;

import java.util.ArrayList;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.clustering.KMeans;
import org.apache.spark.mllib.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;

public class KMeansTest {

	public static void main(String[] args) {
		init(kMeansPre());
	}
	
	public static JavaSparkContext kMeansPre() {
		SparkConf conf = new SparkConf().setAppName("Kmeans").setMaster("local[2]");
		JavaSparkContext jsc = new JavaSparkContext(conf);
		return jsc;
	}
	
	/**
	 * Sample data:
	 *   0.0 0.0 0.0
	 *   0.1 0.1 0.1
	 *   0.2 0.2 0.2
	 *   9.0 9.0 9.0
	 *   9.1 9.1 9.1
	 *   9.2 9.2 9.2
	 * @param jsc the Spark context created by kMeansPre()
	 */
	public static void init(JavaSparkContext jsc) {
		double[] data1 = {0.0,0.0,0.0};
		double[] data2 = {0.1,0.1,0.1};
		double[] data3 = {0.2,0.2,0.2};
		double[] data4 = {9.0,9.0,9.0};
		double[] data5 = {9.1,9.1,9.1};
		double[] data6 = {9.2,9.2,9.2};
		
		List<Vector> preData = new ArrayList<>();
		Vector v1 = Vectors.dense(data1);
		Vector v2 = Vectors.dense(data2);
		Vector v3 = Vectors.dense(data3);
		Vector v4 = Vectors.dense(data4);
		Vector v5 = Vectors.dense(data5);
		Vector v6 = Vectors.dense(data6);
		preData.add(v1);
		preData.add(v2);
		preData.add(v3);
		preData.add(v4);
		preData.add(v5);
		preData.add(v6);
		JavaRDD<Vector> data = jsc.parallelize(preData);

	    // Cluster the data into two classes using KMeans
	    int numClusters = 2;
	    int numIterations = 20;
	    KMeansModel clusters = KMeans.train(data.rdd(), numClusters, numIterations);

	    JavaRDD<Integer> clusterResult = clusters.predict(data);

	    // Evaluate the coefficient for cluster 0, using its center as the reference point
	    double coef = silhouetteCoefficient(data.collect(), clusterResult.collect(), 0,
	            clusters.clusterCenters()[0], clusters.clusterCenters().length);

	    System.out.println("Silhouette coefficient = " + coef);

	}
	
	private static double euclideanDistance(double[] data, double[] center) {
		// Guard against null or empty input before dereferencing
		if (data == null || center == null || data.length == 0 || center.length == 0) {
			return 0.0;
		} else if (center.length != data.length) {
			throw new RuntimeException("Data and center vectors have different lengths.");
		}

		double sum = 0.0;
		for (int i = 0; i < data.length; i++) {
			sum += Math.pow(data[i] - center[i], 2);
		}

		return Math.sqrt(sum);
	}
	   
	/**
	 * s = (b(i) - a(i)) / max(a(i), b(i))
	 * a(i): average distance from the center to the points of its own cluster
	 * b(i): smallest average distance from the center to the points of any other cluster
	 * @param data       all points
	 * @param result     cluster index assigned to each point
	 * @param flag       index of the cluster being evaluated
	 * @param center     center of that cluster
	 * @param centerSize total number of clusters
	 * @return the silhouette coefficient for this cluster's center
	 */
	private static double silhouetteCoefficient(List<Vector> data, List<Integer> result,
			int flag, Vector center, int centerSize) {
		// b(i): smallest average distance from this center to the points of any other cluster
		double minOtherAvg = Double.MAX_VALUE;
		for (int j = 0; j < centerSize; j++) {
			if (j != flag) {
				double otherSum = 0.0; // reset for each cluster so the sums do not accumulate
				int otherCount = 0;
				for (int i = 0; i < data.size(); i++) {
					if (result.get(i) == j) {
						otherSum += euclideanDistance(data.get(i).toArray(), center.toArray());
						otherCount++;
					}
				}
				if (otherCount > 0) {
					minOtherAvg = Math.min(minOtherAvg, otherSum / otherCount);
				}
			}
		}

		// a(i): average distance from this center to the points of its own cluster
		double sameSum = 0.0;
		int sameCount = 0;
		for (int i = 0; i < data.size(); i++) {
			if (result.get(i) == flag) {
				sameSum += euclideanDistance(data.get(i).toArray(), center.toArray());
				sameCount++;
			}
		}
		double sameAvg = sameCount > 0 ? sameSum / sameCount : 0.0;

		return (minOtherAvg - sameAvg) / Math.max(sameAvg, minOtherAvg);
	}
}
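The main method above only evaluates cluster 0. A natural extension, not in the original code, is to average the coefficient over all cluster centers. A minimal sketch, assuming the data, clusterResult, and clusters variables from init are in scope:

	    // Hypothetical extension: average the center-based coefficient over every cluster
	    Vector[] centers = clusters.clusterCenters();
	    double total = 0.0;
	    for (int j = 0; j < centers.length; j++) {
	        total += silhouetteCoefficient(data.collect(), clusterResult.collect(), j,
	                centers[j], centers.length);
	    }
	    System.out.println("Mean silhouette over all centers = " + total / centers.length);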


There may be things I have not considered; thanks for reading.
