A Detailed Introduction to KNN and K-means

The previous post introduced the KNN algorithm (https://blog.csdn.net/jodie123456/article/details/101595943); this post continues with the K-means algorithm:

K-means (K-means clustering): an unsupervised learning algorithm

K-means is a simple iterative clustering algorithm that uses distance as its similarity measure to discover K clusters in a given dataset. Each cluster's center is the mean of all the points assigned to it, and each cluster is described by its center. For a given dataset X and a desired number of clusters K, taking Euclidean distance as the similarity measure, the goal of clustering is to minimize the within-cluster sum of squares, i.e. to minimize the following objective:

J = Σ_{k=1}^{K} Σ_{x_i ∈ C_k} ‖x_i − μ_k‖²

Here n is the size of the dataset, μ_k is the center of cluster C_k, and K is the number of clusters. Combining least squares with the Lagrange method shows that each cluster center should be the mean of the data points in its cluster; for the algorithm to converge, the centers should change as little as possible between successive iterations. Finally, a word on judging clustering quality: informally, a good clustering has small within-cluster distances and large between-cluster distances.
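To make the objective concrete, here is a tiny worked example (the data values are made up for illustration) that computes the within-cluster sum of squares J for a fixed assignment of four points to K = 2 clusters:

```python
import numpy as np

# Toy dataset: 4 points in 2-D, K = 2 clusters (values made up for illustration)
X = np.array([[0.0, 0.0], [0.0, 1.0], [5.0, 5.0], [5.0, 6.0]])
labels = np.array([0, 0, 1, 1])  # assumed cluster assignment

# Each center is the mean of the points assigned to its cluster
centers = np.array([X[labels == k].mean(axis=0) for k in range(2)])

# Within-cluster sum of squares: J = sum_k sum_{x in C_k} ||x - mu_k||^2
J = sum(((X[labels == k] - centers[k]) ** 2).sum() for k in range(2))
print(centers)  # [[0.  0.5] [5.  5.5]]
print(J)        # 0.5 + 0.5 = 1.0
```

Note that each cluster contributes the squared distances of its own points to its own center; moving a point to the wrong cluster can only increase J.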

The K-means algorithm proceeds as follows:
1) Pick K objects from the data space as the initial centers, each object representing one cluster center;
2) For every data object in the sample, compute its Euclidean distance to each cluster center and assign it, by the nearest-distance rule, to the class of the closest (most similar) center;
3) Update the cluster centers: take the mean of all objects in each class as that class's new center, and compute the value of the objective function;
4) If neither the cluster centers nor the objective value changed, output the result; otherwise return to step 2).
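Steps 2) and 3) above can be sketched in a few lines of NumPy; the points and initial centers below are made up for illustration, and one pass of the loop looks like this:

```python
import numpy as np

# Toy points and assumed initial centers (values made up for illustration)
X = np.array([[1.0, 1.0], [1.2, 0.8], [8.0, 8.0], [8.5, 7.5]])
centers = np.array([[0.0, 0.0], [10.0, 10.0]])

# Step 2: squared Euclidean distance of every point to every center,
# computed by broadcasting; result has shape (n_points, K)
d2 = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
labels = d2.argmin(axis=1)  # nearest-center rule

# Step 3: each center becomes the mean of the points assigned to it
centers = np.array([X[labels == k].mean(axis=0) for k in range(len(centers))])
print(labels)   # [0 0 1 1]
print(centers)  # [[1.1  0.9 ] [8.25 7.75]]
```

Repeating these two steps until the assignments stop changing is exactly the loop described in step 4).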

The figure below shows a simple clustering result with K = 3:


Python implementation:

#-*-coding:utf-8 -*-
from numpy import *
from math import sqrt

import sys
sys.path.append("C:/Users/Administrator/Desktop/k-means的python實現")
 
def loadData(fileName):
    data = []
    with open(fileName) as fr:
        for line in fr:
            curline = line.strip().split('\t')
            frline = list(map(float,curline))#Python 3: map returns an iterator, so wrap it in list()
            data.append(frline)
    return data
'''
#test
a = mat(loadData("C:/Users/Administrator/Desktop/k-means/testSet.txt"))
print(a)
'''
#calculate the Euclidean distance
def distElud(vecA,vecB):
    return sqrt(sum(power((vecA - vecB),2)))

#initialization
def randCent(dataSet,k):
    n = shape(dataSet)[1]
    center = mat(zeros((k,n)))
    for j in range(n):
        rangeJ = float(max(dataSet[:,j]) - min(dataSet[:,j]))
        center[:,j] = min(dataSet[:,j]) + rangeJ * random.rand(k,1)
    return center
'''
#test
a = mat(loadData("C:/Users/Administrator/Desktop/k-means/testSet.txt"))
k = 3
b = randCent(a,k)
print(b)
'''
def kMeans(dataSet,k,dist = distElud,createCent = randCent):
    m = shape(dataSet)[0]
    clusterAssment = mat(zeros((m,2)))
    center = createCent(dataSet,k)
    clusterChanged = True
    while clusterChanged:
        clusterChanged = False
        for i in range(m):
            minDist = inf
            minIndex = -1
            for j in range(k):
                distJI = dist(dataSet[i,:],center[j,:])
                if distJI < minDist:
                    minDist = distJI
                    minIndex = j
            if clusterAssment[i,0] != minIndex:#any reassignment means we have not converged yet
                clusterChanged = True
            clusterAssment[i,:] = minIndex,minDist ** 2
        print(center)
        for cent in range(k):#update the cluster centers
            dataCent = dataSet[nonzero(clusterAssment[:,0].A == cent)[0]]
            center[cent,:] = mean(dataCent,axis = 0)#axis=0 averages over the rows, giving the mean of each column
    return center,clusterAssment
'''
#test
dataSet = mat(loadData("C:/Users/Administrator/Desktop/k-means/testSet.txt"))
k = 4
a = kMeans(dataSet,k)
print(a)
'''
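As a quick sanity check on the loop above, the same algorithm can be run end to end in compact NumPy form on a toy dataset. This sketch is self-contained and independent of the functions defined earlier; the data values and initial centers are made up for illustration:

```python
import numpy as np

def kmeans_np(X, k, centers, max_iter=100):
    """Compact NumPy version of the kMeans loop (illustrative sketch)."""
    labels = np.zeros(len(X), dtype=int)
    for _ in range(max_iter):
        # assignment step: nearest center by squared Euclidean distance
        d2 = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels = d2.argmin(axis=1)
        # update step: mean of assigned points; keep an empty cluster's center
        new = np.array([X[labels == j].mean(axis=0) if (labels == j).any()
                        else centers[j] for j in range(k)])
        if np.allclose(new, centers):  # converged: centers stopped moving
            break
        centers = new
    return centers, labels

# Toy data: two well-separated groups (values made up for illustration)
X = np.array([[0.0, 0.0], [0.1, 0.2], [0.2, 0.1],
              [5.0, 5.0], [5.1, 5.2], [4.9, 5.1]])
centers, labels = kmeans_np(X, 2, centers=np.array([[0.0, 0.0], [1.0, 1.0]]))
print(centers)  # close to [[0.1 0.1] [5.  5.1]]
```

With well-separated groups the loop settles after a couple of iterations, each final center landing on the mean of its group.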

MATLAB implementation:

%%%K-means
clear all
clc

%% generate sample data: first cluster
mu1=[0 0 0];  
S1=[0.23 0 0;0 0.87 0;0 0 0.56]; 
data1=mvnrnd(mu1,S1,100);

%% second cluster of data
mu2=[1.25 1.25 1.25];
S2=[0.23 0 0;0 0.87 0;0 0 0.56];
data2=mvnrnd(mu2,S2,100);

% third cluster of data
mu3=[-1.25 1.25 -1.25];
S3=[0.23 0 0;0 0.87 0;0 0 0.56];
data3=mvnrnd(mu3,S3,100);

% fourth cluster of data
mu4=[1.5 1.5 1.5];
S4=[0.23 0 0;0 0.87 0;0 0 0.56];
data4 =mvnrnd(mu4,S4,100);

% plot the raw data
figure;
plot3(data1(:,1),data1(:,2),data1(:,3),'+');
title('Raw data');
hold on
plot3(data2(:,1),data2(:,2),data2(:,3),'r+');
plot3(data3(:,1),data3(:,2),data3(:,3),'g+');
plot3(data4(:,1),data4(:,2),data4(:,3),'y+');%bug fix: z-coordinate must come from data4, not data3
grid on;


data=[data1;data2;data3;data4];   
[row,col] = size(data);
K = 4;
max_iter = 300;%maximum number of iterations
min_impro = 0.1;%minimum improvement between iterations (convergence threshold)
display = 1;%display flag
center = zeros(K,col);
U = zeros(K,col);
%% initialize the cluster centers randomly within the range of the data
mi = zeros(col,1);
ma = zeros(col,1);
for i = 1:col
    mi(i,1) = min(data(:,i));
    ma(i,1) = max(data(:,i));
    center(:,i) = ma(i,1) - (ma(i,1) - mi(i,1)) * rand(K,1);
end

%% main iteration loop
for o = 1:max_iter
    %% offset of every point from every center (Euclidean distances via norm below)
    for i = 1:K
        dist{i} = [];
        for j = 1:row
            dist{i} = [dist{i};data(j,:) - center(i,:)];
        end
    end
    
    %% assign each point to its nearest center
    minDis = zeros(row,K);
    for i = 1:row
        tem = [];
        for j = 1:K
            tem = [tem norm(dist{j}(i,:))];
        end
        [nmin,index] = min(tem);
        minDis(i,index) = 1;%membership indicator: 1 marks the assigned cluster, 0 otherwise
    end
    
    
     %% update the cluster centers
     for i = 1:K
        for j = 1:col
            U(i,j) = sum(minDis(:,i).*data(:,j)) / sum(minDis(:,i));
        end
     end
     
   %% convergence check: stop once no center moves by more than min_impro
   if o > 1 && max(max(abs(U - center))) < min_impro
       break;
   end
   center = U;
end

 %% assign each point its final class label
 class = [];
 for i = 1:row
     dist = [];
     for j = 1:K
         dist = [dist norm(data(i,:) - U(j,:))];
     end
     [nmin,index] = min(dist);
     class = [class;data(i,:) index];
 end
  
 %% plot the final clustering result
[m,n] = size(class);
figure;
title('Clustering result');
hold on;
for i=1:row 
    if class(i,4)==1   
         plot3(class(i,1),class(i,2),class(i,3),'ro'); 
    elseif class(i,4)==2
         plot3(class(i,1),class(i,2),class(i,3),'go'); 
    elseif class(i,4) == 3
         plot3(class(i,1),class(i,2),class(i,3),'bo'); 
    else
        plot3(class(i,1),class(i,2),class(i,3),'yo'); 
    end
end
grid on;

 
