計算智能——K-means聚類算法的原理和實現(C語言)

計算智能——K-means聚類算法的原理和實現(C語言)

算法定義

k-means聚類算法是一種迭代求解的聚類分析算法。,
k均值聚類是最著名的劃分聚類算法,由於簡潔和效率使得他成爲所有聚類算法中最廣泛使用的。給定一個數據點集合和需要的聚類數目k,k由用戶指定,k均值算法根據某個距離函數反覆把數據分入k個聚類中。

算法原理

k-means算法首先選取k個點作爲初始的聚類中心,然後計算各個樣本到各聚類中心的距離,把每個樣本歸入離它最近的那個聚類中心所在的類; 調整後的新類計算新的聚類中心,如果相鄰兩次的聚類中心沒有任何變化,這說明數據對象調整結束,聚類準則函數f已經收斂。在每次迭代過程中都要考察每個樣本的分類是否正確,若不正確,就要調整。在全部數據調整完後,再修改聚類中心,進入下一次迭代。如果在某一次迭代算法中,所有的數據 對象被正確分類,則不會有調整,聚類中心也不會有任何變化,這標誌着f已經收斂,算法結束。該算法可分爲四個步驟

  1. 選定k箇中心點,選定n個樣本,輸入這些樣本。
  2. 爲每個樣本找到距離其最近的中心點(尋找組織),距離同一中心點最近的點爲一個類,這樣完成了一次聚類。
  3. 判斷聚類前後的樣本點的類別情況是否相同(及兩次聚類的平方誤差是否相同),如果相同,則算法結束,否則進入第四步。
  4. 針對每個類別中的樣本點,計算這些樣本的中心點,以此作爲該類新的中心點,繼續第二步。

算法流程圖

這是

算法實現

       #include <stdio.h>
        #include <string.h>
        #include <stdlib.h>
        #include <math.h>
        #include <time.h>
        #define max 100
        typedef struct
        {
        	float x;
        	float y;
        }Point;
        Point point[max];
        Point mean[max];  ///  保存每個簇的中心點
        
        int center[max];  ///  判斷每個點屬於哪個簇
        int Num;
        int K;
         
        //獲得兩點間直線距離 
        float getDistance(Point point1, Point point2)
        {
        	float d;
        	d = sqrt((point1.x - point2.x) * (point1.x - point2.x) + (point1.y - point2.y) * (point1.y - point2.y));
        	return d;
        }
         
        // 計算每個簇的中心點 把歸屬於該中心點的點相加除以點數 
        void getMean(int center[max])
        {
        	Point tep;
        	int i, j, count = 0;
        	for(i = 0; i < K; i++)
        	{
        		count = 0;
        		tep.x = 0.0;   /// 每算出一個簇的中心點值後清0
        		tep.y = 0.0;
                for(j = 0; j < Num; j++)
        		{
        			if(i == center[j])
        			{
        				count++;
        				tep.x += point[j].x;
        				tep.y += point[j].y;
        			}
        		}
        		tep.x /= count;
        		tep.y /= count;
        		mean[i] = tep;
        	}
        	for(i = 0; i < K; i++)
            {
            	printf("The new center point of %d is : \t( %f, %f )\n", i+1, mean[i].x, mean[i].y);
            }
        }
         
        /// 計算平方誤差函數(x1-x2)^2+(y1-y2)^2  計算當前各點與當前所屬中心點間距離誤差  
        float getError()
        {
        	int i, j;
        	float cnt = 0.0, sum = 0.0;
        	for(i = 0; i < K; i++)
        	{
        		for(j = 0; j < Num; j++)
        		{
        			if(i == center[j])
        			{
        				cnt = (point[j].x - mean[i].x) * (point[j].x - mean[i].x) + (point[j].y - mean[i].y) * (point[j].y - mean[i].y);
        				sum += cnt;
        			}
        		}
        	}
        	return sum;
        }
         
        // 把Num個點聚類
        void cluster()
        {
        	int i, j, q;
        	float min;
        	float distance[Num][K];
        	for(i = 0; i < Num; i++)
        	{
        		min = 9999.0;
        		//獲得(point[i].x,point[i].y)的點到每個中心點的距離 
        		for(j = 0; j < K; j++)
        		{
        			distance[i][j] = getDistance(point[i], mean[j]);
        		}
        		//計算每個點到三個中心點的距離,如果發現有直線距離更短的中心點,則將自己歸入其中,放棄當前中心點 
        		for(q = 0; q < K; q++)
        		{
        			if(distance[i][q] < min)
        			{
        				min = distance[i][q];
                		center[i] = q;
        			}
        		}
        		//輸出這個過程中的該點及其歸類 
        		printf("( %.0f, %.0f )\t in cluster %d\n", point[i].x, point[i].y, center[i] + 1);
        	}
        	printf("-----------------------------\n");
        }
         
        int main()
        {
            int i, j, n = 0;
            float temp1;
            float temp2, t;
            printf("Please input the num:\n");
            scanf("%d", &Num); 
            printf("Please input the k:\n");
            scanf("%d", &K);
            printf("Please input %d coordinates:\n",Num);
            for(i = 0; i < Num; i++){
            	scanf("%f%f",&point[i].x,&point[i].y);
        	}
           
            printf("-----------------------------\n");
         
    		for(i = 0;i < K ; i++){
              	mean[i].x = point[i].x;      /// 初始化k箇中心點
              	mean[i].y = point[i].y;
         	 }
         
            cluster();          /// 第一次根據預設的k個點進行聚類
            temp1 = getError();        ///  第一次平方誤差
            n++;                   ///  n計算形成最終的簇用了多少次
         
            printf("The square error on 1 is: %f\n\n", temp1);
         
            getMean(center);
            cluster();
            temp2 = getError();        ///  根據簇形成新的中心點,並計算出平方誤差
            n++;
         
            printf("The square error on 2 is: %f\n\n", temp2);
         
            while(fabs(temp2 - temp1) != 0)   ///  比較兩次平方誤差 判斷是否相等,不相等繼續迭代
            {
            	temp1 = temp2;
                getMean(center);
            	cluster();
            	temp2 = getError();
            	n++;
            	printf("The square error on %d is: %f\n", n, temp2);
            }
         
            printf("The total round of cluster is: %d\n\n", n);  /// 統計出迭代次數
            system("pause");
            return 0;
        } 

參考資料

Polykovskiy, D. and Novikov, A., Bayesian Methods for Machine Learning

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章