聚類系列-KMEANS

k-means聚類算法是聚類算法中應用非常廣泛的一種算法。它是屬於劃分法的一種,是一種基於距離的聚類方法,在聚類的開始需要指定一個K值,表示需要聚類的數目。

k-means聚類算法的思想非常容易理解:拿到待聚類的N個樣本和需要聚類的數目K(K<N)(k怎麼選擇後邊會介紹)以後,隨機的在N個樣本中選擇K個點作爲初始的聚類中心,然後計算剩餘的點與每一個聚類中心點的距離(當然有許多種計算距離的方法,下邊會介紹)並選擇與自身距離最小的聚類中心點爲聚類中心(也就是選擇與自身距離最小的那個聚類中心點同一個類別),這樣第一次迭代結束,接下來就是看這次的迭代能否達到你設定的目標(就是迭代終止條件,後邊會介紹),若是達到了那麼聚類結束,否則的話進行下一輪的迭代。下一輪的迭代其實就是重新計算聚類中心點(有可能不是你的樣本數據點),然後計算其他點與新聚類中心的距離,重新選擇類別。就這樣依次迭代直到達到設定的終止條件。用算法的形式表示爲:

    設定聚類數目K

    在N個樣本中隨機的選擇K個樣本點作爲初始的聚類中心點;

    repeat

        計算剩餘的點和每一個聚類中心點的距離並選擇與自身距離最小的聚類中心點爲聚類中心;
        重新計算每一個聚類的聚類中心;

    until   迭代的終止條件

    舉個具體的例子來說明:

        例:我有平面上的8(N=8)個樣本點:p0(2,3)、p1(2,5)、p2(3,5)、p3(4,6)、p4(5,4)、p5(3,1)、p6(0,6)、p7(7,7),我想聚成3個類別。

        按照上邊的算法講的步驟:

            我設定聚類的數目K=3;

            之後我在這8個樣本中隨機的選擇3個點爲初始的聚類中心點,比方說我選擇了p1、p4、p5爲初始的聚類中心點。

             repeat

                 剩餘的點有p0、p2、p3、p6、p7,分別計算他們與聚類中心點p1、p4、p5距離,這裏我們選擇距離的計算方法爲歐氏距離。通過計算(拿p0爲例)可得

                (p0,p1)=2     (p0,p4)=sqrt(10)    (p0,p5)=sqrt(5),由距離我們可以得到p0和p1的距離最小,所以將p0放到p1所在的類中。 用同樣的方法我們可以

                將p2、p3、p6、p7分別放到p1、p1、p1、p4。這樣我們得到第一輪的三個類簇c1(p0,p1,p2,p3,p6),c2(p4,p7),c3(p5)。接下來我們重新計算聚

                 類中心點,用的方法就是取類簇中所有樣本的均值,以類簇c2爲例計算可得c2的新的聚類中心爲((5+7)/2,(4+7)/2)=(6,5.5),同樣的方法可以得到其 他

                兩個類簇的新的聚類中心c1=(2.2,5),c3=(3,1)。

                 注意(只是說第二輪,以後的每一輪一次類推):若是下一輪有迭代,剩餘點爲p0、p1、p2、p3p4、p6、 p7, 因爲p5依然是聚類中心點。

            util   聚類中心不變或者聚類中心的變化量小於某一個閾值或者達到迭代次數。

    k-means聚類算法中有幾個關鍵點需要注意一下:

        K值的選擇:

             在實際的應用中k值的選擇一般是靠經驗來選擇的,多試幾次,選擇其中對你的所要解決的問題最好的聚類數目。但是在學術上或者其他一些博客中也有一些對K值選擇

             的算 法,不過本人沒有太去研究,不好做出評論。在這篇博客中點擊打開鏈接,有對K值的選擇和對初始聚類中心的選擇的介紹,有興趣的可以去看一下。

         初始聚類中心點的選擇:

             工程上對於初始聚類中心點的選擇,一般使用隨機選擇然後去迭代,都能夠取的不錯的效果。當然,也可以去參照K值選擇中的那篇博客去求解。

        與聚類中心距離的計算:

            在工程中我使用的比較多的是歐氏距離和餘弦夾角,這個想必不用介紹了,大家都應該清楚的。

        迭代條件:

            迭代條件關係到你的聚類的效果的問題,一般迭代條件會是三種情況:第一種是你設定迭代的次數(50、100或者更多。。。。自己設定),當迭代次數達到你所設定

             次數,聚類自動終止;第二種是聚類中心不變,也就是上一輪迭代和下一輪迭代的所有的聚類中心都不變化了,這種情況下一般難以收斂,通常會和迭代次數一起使用

            ;第三種情況就是設定一個閾值,當所有的聚類中心點的變化範圍都不大於你設置的閾值的時候,聚類結束,也可以與迭代次數一塊使用。

     好了,上邊講的這麼多都是從我的實際應用中總結的,若是其他人有更好的見解,請不吝賜教。下邊上代碼(Java版,我可能會在寫一份Python的供大家使用):

    主類(Cluster_Kmeans.java):

package com.zc.test;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;

import com.zc.source.kmeans;
import com.zc.source.kmeans_data;

public class Cluster_Kmeans {

	public static void main(String[] args) {

		ArrayList<double[]> dimension = new ArrayList<double[]>();//存放樣本詞的向量值
		List<String> dataSet = new LinkedList<String>();//存放樣本中的詞

		File Input = new File("E:\\test_csdn\\test.txt");// 輸入數據

		File Output = new File("E:\\test_csdn\\rel.txt");// 輸出結果

		if (Output.exists()) {
			Output.delete();
		}

		FileInputStream fis = null;
		InputStreamReader isr = null;
		BufferedReader br = null;

		FileOutputStream fos = null;
		OutputStreamWriter osw = null;
		BufferedWriter bw = null;

		try {
			fis = new FileInputStream(Input);
			isr = new InputStreamReader(fis, "utf-8");
			br = new BufferedReader(isr);

			fos = new FileOutputStream(Output);
			osw = new OutputStreamWriter(fos, "utf-8");
			bw = new BufferedWriter(osw);

			String line = br.readLine();
			String s[] = null;
			while (line != null) {
				double b[] = new double[200];
				s = line.split(" ");//樣本中的數據包詞語和對應的向量值,中間使用空格隔開的,所以取出來的時候要用空格劃分
				for (int i = 1; i < s.length; i++) {
					b[i - 1] = Double.parseDouble(s[i]);
				}
				dimension.add(b);//放入某個詞的向量值
				dataSet.add(s[0]);//放入詞語
				line = br.readLine();
			}
			System.out.println("數據加載完畢---------------");
			
			double[][] ff = new double[dataSet.size()][200];
			for (int i = 0; i < dimension.size(); i++) {
				ff[i] = dimension.get(i);//將詞語和對應的向量值對應起來
			}
			 // 初始化數據結構
			kmeans_data data = new kmeans_data(ff, dataSet.size(), 200);

			// 調用doKmeans方法進行聚類,參數列表:聚類數目,數據集,迭代次數,聚類中心變化閾值
			kmeans.doKmeans(2, data, 4000,0.0);

			// 輸出聚類結果
			for (int i = 0; i < dataSet.size(); i++) {
				bw.write(dataSet.get(i));
				bw.write(" ");
				bw.write(String.valueOf(data.labels[i]));
				bw.write("\r\n");
			}
		} catch (Exception ex) {
			ex.printStackTrace();
		} finally {
			try {
				bw.flush();
				osw.flush();
				fos.flush();

				fos.close();
				br.close();
				isr.close();
				fis.close();
			} catch (IOException e) {
				e.printStackTrace();
			}
		}

	}
}
    數據類(kmeans_data.java):
package com.zc.source;

public class kmeans_data {
	public double[][] data;//存放詞語和向量
	public double[] dis;//每個樣本和聚類中心的距離
	public int length;//樣本大小N
	public int dim;//向量的維度
	public int[] labels;//樣本所屬的類簇的標籤
	public double[][] centers;//存放聚類中心
	public int[] centerCounts;//存放某個類簇中含有的樣本的個數
	
	public kmeans_data(double[][] data, int length, int dim) {
		this.data = data;
		this.length = length;
		this.dim = dim;
	}
}
    k-means的處理類(kmeans.java):
package com.zc.source;

import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;

public class kmeans {

	/**
	 * double[][] 元素全置0
	 * 
	 * @param matrix
	 *            double[][]
	 * @param highDim
	 *            int
	 * @param lowDim
	 *            int <br/>
	 *            double[highDim][lowDim]
	 */
	private static void setDouble2Zero(double[][] matrix, int highDim, int lowDim) {
		for (int i = 0; i < highDim; i++) {
			for (int j = 0; j < lowDim; j++) {
				matrix[i][j] = 0;
			}
		}
	}

	/**
	 * 拷貝源二維矩陣元素到目標二維矩陣。 foreach (dests[highDim][lowDim] =
	 * sources[highDim][lowDim]);
	 * 
	 * @param dests
	 *            double[][]
	 * @param sources
	 *            double[][]
	 * @param highDim
	 *            int
	 * @param lowDim
	 *            int
	 */
	private static void copyCenters(double[][] dests, double[][] sources, int highDim, int lowDim) {
		for (int i = 0; i < highDim; i++) {
			for (int j = 0; j < lowDim; j++) {
				dests[i][j] = sources[i][j];
			}
		}
	}

	/**
	 * 更新聚類中心座標
	 * 
	 * @param k
	 *            int 分類個數
	 * @param data
	 *            kmeans_data
	 */
	private static void updateCenters(int k, kmeans_data data) {
		double[][] centers = data.centers;
		setDouble2Zero(centers, k, data.dim);
		int[] labels = data.labels;
		int[] centerCounts = data.centerCounts;
		for (int i = 0; i < data.dim; i++) {
			for (int j = 0; j < data.length; j++) {
				centers[labels[j]][i] += data.data[j][i];
			}
		}
		for (int i = 0; i < k; i++) {
			for (int j = 0; j < data.dim; j++) {
				centers[i][j] = centers[i][j] / centerCounts[i];
			}
		}
	}

	/**
	 * 計算兩點餘弦值
	 * 
	 * @param pa
	 *            double[]
	 * @param pb
	 *            double[]
	 * @param dim
	 *            int 維數
	 * @return double 距離
	 */
	public static double dist(double[] pa, double[] pb, int dim) {
		double mpa = 0;// pa的莫
		double mpb = 0;// pb的莫
		double proab = 0;// pa和pb的向量積
		for (int i = 0; i < dim; i++) {
			proab = proab + pa[i] * pb[i];
			mpa = mpa + pa[i] * pa[i];
			mpb = mpb + pb[i] * pb[i];
		}
		double temp = 0;
		temp = Math.sqrt(mpa) * Math.sqrt(mpb);
		double result = proab / temp;
		return result;
	}
	/**
	 * 計算兩次聚類中心的歐式距離
	 * 
	 * @param pa
	 *            double[]
	 * @param pb
	 *            double[]
	 * @param dim
	 *            int 維數
	 * @return double 距離
	 */
	public static double distcen(double[] pa, double[] pb, int dim) {
		double rv = 0;
		for (int i = 0; i < dim; i++) {
			double temp = pa[i] - pb[i];
			temp = temp * temp;
			rv += temp;
		}
		return Math.sqrt(rv);
	}

	/**
	 * 做Kmeans運算
	 * 
	 * @param k
	 *            int 聚類個數
	 * @param data
	 *            kmeans_data kmeans數據類
	 * @param param
	 *            kmeans_param kmeans參數類
	 * @return kmeans_result kmeans運行信息類
	 */
	public static void doKmeans(int k, kmeans_data data, int maxAttempts,double criteria) {
		// 預處理
		double[][] centers = new double[k][data.dim]; // 聚類中心點集
		data.centers = centers;
		int[] centerCounts = new int[k]; // 各聚類的包含點個數
		data.centerCounts = centerCounts;
		Arrays.fill(centerCounts, 0);
		int[] labels = new int[data.length]; // 各個點所屬聚類標號
		data.labels = labels;
		double[] dis = new double[data.length]; // 各個點於聚類中心的距離
		data.dis = dis;
		double[][] oldCenters = new double[k][data.dim]; // 臨時緩存舊的聚類中心座標

		// 初始化聚類中心(隨機選擇data內的k個不重複點)
		Random rn = new Random();
		List<Integer> seeds = new LinkedList<Integer>();
		while (seeds.size() < k) {
			int randomInt = rn.nextInt(data.length);
			if (!seeds.contains(randomInt)) {
				seeds.add(randomInt);
			}
		}
		Collections.sort(seeds);
		for (int i = 0; i < k; i++) {
			int m = seeds.remove(0);
			for (int j = 0; j < data.dim; j++) {
				centers[i][j] = data.data[m][j];
			}
		}

		// 第一輪迭代
		for (int i = 0; i < data.length; i++) {
			double maxDist = dist(data.data[i], centers[0], data.dim);
			int label = 0;
			for (int j = 1; j < k; j++) {
				double tempDist = dist(data.data[i], centers[j], data.dim);
				if (tempDist > maxDist) {
					maxDist = tempDist;
					label = j;
				}
			}
			dis[i] = maxDist;
			labels[i] = label;
			centerCounts[label]++;
		}
		updateCenters(k, data);//更新聚類中心
		copyCenters(oldCenters, centers, k, data.dim);//賦值聚類中心

		// 迭代預處理
		int attempts = 1;
		boolean[] flags = new boolean[k]; // 標記哪些中心被修改過

		int it = 2;
		// 迭代
		iterate: while (attempts < maxAttempts) { // 迭代次數不超過最大值,最大中心改變量不超過閾值ֵ
			for (int i = 0; i < k; i++) { // 初始化中心點“是否被修改過”標記
				flags[i] = false;
			}
			for (int i = 0; i < data.length; i++) { // 遍歷data內所有點
				double maxDist = dist(data.data[i], centers[0], data.dim);
				int label = 0;
				for (int j = 1; j < k; j++) {
					double tempDist = dist(data.data[i], centers[j], data.dim);
					if (tempDist > maxDist) {
						maxDist = tempDist;
						label = j;
					}
				}
				if (label != labels[i]) { // 如果當前點被聚類到新的類別則做更新
					int oldLabel = labels[i];
					labels[i] = label;
					centerCounts[oldLabel]--;
					centerCounts[label]++;
					flags[oldLabel] = true;
					flags[label] = true;
				}
				dis[i] = maxDist;
			}
			updateCenters(k, data);
			attempts++;

			// 得到被修改過的中心點最大修改量ֵ
			double maxDist = 0;
			for (int i = 0; i < k; i++) {
				if (flags[i]) {
					double tempDist = distcen(centers[i], oldCenters[i], data.dim);
					if (maxDist < tempDist) {
						maxDist = tempDist;
					}
					for (int j = 0; j < data.dim; j++) { // 更新oldCenter
						oldCenters[i][j] = centers[i][j];
					}
				}
			}

			System.out.println("迭代第" + it + "次");
			it++;
			if (maxDist == criteria) {//查看被修改過的中心點最大修改量是否超過閾值
				break iterate;
			}
		}
	}
}

    以上爲全部的代碼,我的數據格式爲:word v1 v2 v3 ........具體的例子就是:中國 0.44 0.56 0.78 0.09 ...中間是用空格隔開的,因爲在我的程序中我用的

    樣本是詞語加上200維的向量,所以在我的主程序有些double數組中我直接寫的是200,其他人若是用的時候需要根據具體情況去修改這個值。結果的格

    式爲:word label 例如:中國 1,說明中國這個詞屬於類簇1。

            

                

                      

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章