計算智能——K-means聚類算法的原理和實現(C語言)
算法定義
k-means聚類算法是一種迭代求解的聚類分析算法。,
k均值聚類是最著名的劃分聚類算法,由於簡潔和效率使得他成爲所有聚類算法中最廣泛使用的。給定一個數據點集合和需要的聚類數目k,k由用戶指定,k均值算法根據某個距離函數反覆把數據分入k個聚類中。
算法原理
k-means算法首先選取k個點作爲初始的聚類中心,然後計算各個樣本到各聚類中心的距離,把每個樣本歸入離它最近的那個聚類中心所在的類; 調整後的新類計算新的聚類中心,如果相鄰兩次的聚類中心沒有任何變化,這說明數據對象調整結束,聚類準則函數f已經收斂。在每次迭代過程中都要考察每個樣本的分類是否正確,若不正確,就要調整。在全部數據調整完後,再修改聚類中心,進入下一次迭代。如果在某一次迭代算法中,所有的數據 對象被正確分類,則不會有調整,聚類中心也不會有任何變化,這標誌着f已經收斂,算法結束。該算法可分爲四個步驟
- 選定k箇中心點,選定n個樣本,輸入這些樣本。
- 爲每個樣本找到距離其最近的中心點(尋找組織),距離同一中心點最近的點爲一個類,這樣完成了一次聚類。
- 判斷聚類前後的樣本點的類別情況是否相同(及兩次聚類的平方誤差是否相同),如果相同,則算法結束,否則進入第四步。
- 針對每個類別中的樣本點,計算這些樣本的中心點,以此作爲該類新的中心點,繼續第二步。
算法流程圖
算法實現
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <math.h>
#include <time.h>
#define max 100
typedef struct
{
float x;
float y;
}Point;
Point point[max];
Point mean[max]; /// 保存每個簇的中心點
int center[max]; /// 判斷每個點屬於哪個簇
int Num;
int K;
//獲得兩點間直線距離
float getDistance(Point point1, Point point2)
{
float d;
d = sqrt((point1.x - point2.x) * (point1.x - point2.x) + (point1.y - point2.y) * (point1.y - point2.y));
return d;
}
// 計算每個簇的中心點 把歸屬於該中心點的點相加除以點數
void getMean(int center[max])
{
Point tep;
int i, j, count = 0;
for(i = 0; i < K; i++)
{
count = 0;
tep.x = 0.0; /// 每算出一個簇的中心點值後清0
tep.y = 0.0;
for(j = 0; j < Num; j++)
{
if(i == center[j])
{
count++;
tep.x += point[j].x;
tep.y += point[j].y;
}
}
tep.x /= count;
tep.y /= count;
mean[i] = tep;
}
for(i = 0; i < K; i++)
{
printf("The new center point of %d is : \t( %f, %f )\n", i+1, mean[i].x, mean[i].y);
}
}
/// 計算平方誤差函數(x1-x2)^2+(y1-y2)^2 計算當前各點與當前所屬中心點間距離誤差
float getError()
{
int i, j;
float cnt = 0.0, sum = 0.0;
for(i = 0; i < K; i++)
{
for(j = 0; j < Num; j++)
{
if(i == center[j])
{
cnt = (point[j].x - mean[i].x) * (point[j].x - mean[i].x) + (point[j].y - mean[i].y) * (point[j].y - mean[i].y);
sum += cnt;
}
}
}
return sum;
}
// 把Num個點聚類
void cluster()
{
int i, j, q;
float min;
float distance[Num][K];
for(i = 0; i < Num; i++)
{
min = 9999.0;
//獲得(point[i].x,point[i].y)的點到每個中心點的距離
for(j = 0; j < K; j++)
{
distance[i][j] = getDistance(point[i], mean[j]);
}
//計算每個點到三個中心點的距離,如果發現有直線距離更短的中心點,則將自己歸入其中,放棄當前中心點
for(q = 0; q < K; q++)
{
if(distance[i][q] < min)
{
min = distance[i][q];
center[i] = q;
}
}
//輸出這個過程中的該點及其歸類
printf("( %.0f, %.0f )\t in cluster %d\n", point[i].x, point[i].y, center[i] + 1);
}
printf("-----------------------------\n");
}
int main()
{
int i, j, n = 0;
float temp1;
float temp2, t;
printf("Please input the num:\n");
scanf("%d", &Num);
printf("Please input the k:\n");
scanf("%d", &K);
printf("Please input %d coordinates:\n",Num);
for(i = 0; i < Num; i++){
scanf("%f%f",&point[i].x,&point[i].y);
}
printf("-----------------------------\n");
for(i = 0;i < K ; i++){
mean[i].x = point[i].x; /// 初始化k箇中心點
mean[i].y = point[i].y;
}
cluster(); /// 第一次根據預設的k個點進行聚類
temp1 = getError(); /// 第一次平方誤差
n++; /// n計算形成最終的簇用了多少次
printf("The square error on 1 is: %f\n\n", temp1);
getMean(center);
cluster();
temp2 = getError(); /// 根據簇形成新的中心點,並計算出平方誤差
n++;
printf("The square error on 2 is: %f\n\n", temp2);
while(fabs(temp2 - temp1) != 0) /// 比較兩次平方誤差 判斷是否相等,不相等繼續迭代
{
temp1 = temp2;
getMean(center);
cluster();
temp2 = getError();
n++;
printf("The square error on %d is: %f\n", n, temp2);
}
printf("The total round of cluster is: %d\n\n", n); /// 統計出迭代次數
system("pause");
return 0;
}
參考資料
Polykovskiy, D. and Novikov, A., Bayesian Methods for Machine Learning