在上一篇博文中介绍了KNN算法(https://blog.csdn.net/jodie123456/article/details/101595943),接下来继续介绍K-means算法:
K-means(K均值聚类):(属于无监督学习)
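K-means算法是一种简单的迭代型聚类算法,采用距离作为相似性指标,从而发现给定数据集中的K个类,并且每一个类的中心是根据该类中所有值的平均值得到的,每一个类用聚类中心来描述。对于给定的一个数据集X和需要分得的类别数K,选取欧氏距离作为相似性指标,聚类目标是使得各类的误差平方和最小,即最小化下面公式:
$$J=\sum_{k=1}^{K}\sum_{i=1}^{n} r_{ik}\,\lVert x_i-\mu_k\rVert^{2}$$
其中n是数据集的大小,$x_i$是第i个数据点,$\mu_k$是第k类的聚类中心,K是要聚类的类别数;当$x_i$属于第k类时$r_{ik}=1$,否则$r_{ik}=0$。结合最小二乘法和拉格朗日原理,聚类中心为对应类别中各数据点的平均值,同时为了使得算法收敛,在迭代过程中,应使最终的聚类中心尽可能地不变。最后说下如何评估聚类的好坏,通俗地说就是类内间距小,类间间距大。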
K-means算法流程如下所示:
1)选取数据空间中的K个对象作为初始中心,每个对象代表一个聚类中心;
2)对于样本中的数据对象,根据它们与这些聚类中心的欧氏距离,按距离最近的准则将它们分到距离它们最近的聚类中心(最相似)所对应的类;
3)更新聚类中心:将每个类别中所有对象所对应的均值作为该类别的聚类中心,计算目标函数的值;
4)判断聚类中心和目标函数的值是否发生改变,若不变,则输出结果,若改变,则返回2)。
下图是简单的聚类结果:K=3
python代码实现:
#-*-coding:utf-8 -*-
from numpy import *
from math import sqrt
import sys
# Add a directory to the module search path; the value is an absolute path on
# one specific Windows machine and has to be adapted elsewhere.
sys.path.append("C:/Users/Administrator/Desktop/k-means的python实现")
def loadData(fileName):
    """Read a tab-separated text file of numbers into a list of rows.

    Every non-blank line is split on tab characters and each field is
    converted with ``float``.

    Args:
        fileName: Path of the text file to read.

    Returns:
        A list holding one list of floats per non-blank line.

    Raises:
        ValueError: If a field cannot be converted to ``float``.
    """
    data = []
    # The context manager closes the file even when a conversion fails.
    with open(fileName) as fr:
        for line in fr:
            stripped = line.strip()
            if not stripped:
                # Skip blank lines (e.g. a trailing empty line at end of file).
                continue
            # Build a real list: on Python 3 ``map`` returns a lazy iterator,
            # which would not form a numeric row afterwards.
            data.append([float(field) for field in stripped.split('\t')])
    return data
# Manual smoke test for loadData; it is wrapped in a bare string literal and
# therefore never executed.
'''
#test
a = mat(loadData("C:/Users/Administrator/Desktop/k-means/testSet.txt"))
print a
'''
# calculate the distance
def distElud(vecA, vecB):
    """Return the Euclidean distance between two vectors.

    Args:
        vecA: NumPy array or matrix holding the first vector.
        vecB: NumPy array or matrix with the same shape as ``vecA``.

    Returns:
        The square root of the sum of squared element-wise differences.
    """
    squared = power(vecA - vecB, 2)
    # Use the array's own ``sum`` method so the result does not depend on
    # whether the name ``sum`` currently refers to NumPy's or the builtin.
    return sqrt(squared.sum())
# initialization
def randCent(dataSet, k):
    """Create ``k`` random initial centers inside the bounding box of the data.

    For every column, the centers are drawn uniformly between that column's
    minimum and maximum.

    Args:
        dataSet: 2-D NumPy array or matrix, one sample per row.
        k: Number of centers to generate.

    Returns:
        A ``k`` x ``n`` NumPy matrix of centers, where ``n`` is the number of
        columns of ``dataSet``.
    """
    n = shape(dataSet)[1]
    # ``asmatrix`` replaces the ``mat`` alias, which NumPy 2.0 no longer has.
    center = asmatrix(zeros((k, n)))
    for j in range(n):
        column = dataSet[:, j]
        # The array methods return plain scalars, independent of which
        # ``min``/``max`` names the star import left in scope.
        lo = float(column.min())
        span = float(column.max()) - lo
        center[:, j] = lo + span * random.rand(k, 1)
    return center
# Manual smoke test for randCent; it is wrapped in a bare string literal and
# therefore never executed.
'''
#test
a = mat(loadData("C:/Users/Administrator/Desktop/k-means/testSet.txt"))
n = 3
b = randCent(a,3)
print b
'''
def kMeans(dataSet, k, dist=distElud, createCent=randCent):
    """Cluster the rows of ``dataSet`` into ``k`` groups with K-means.

    The loop alternates between assigning every sample to its nearest center
    and moving every center to the mean of its members, and stops once a full
    pass changes no assignment.

    Args:
        dataSet: 2-D NumPy matrix, one sample per row.
        k: Number of clusters.
        dist: Callable ``dist(rowA, rowB)`` returning the distance between
            two rows.
        createCent: Callable ``createCent(dataSet, k)`` returning the initial
            ``k`` centers as a matrix with one center per row.

    Returns:
        A tuple ``(center, clusterAssment)``. ``center`` holds one center per
        row. ``clusterAssment`` has one row per sample: column 0 is the index
        of the assigned cluster, column 1 the squared distance to its center.
    """
    m = shape(dataSet)[0]
    # ``asmatrix`` replaces the ``mat`` alias, which NumPy 2.0 no longer has.
    clusterAssment = asmatrix(zeros((m, 2)))
    # Start from an impossible cluster index so that the very first assignment
    # of a sample always counts as a change, including assignment to cluster 0.
    clusterAssment[:, 0] = -1
    center = createCent(dataSet, k)
    clusterChanged = True
    while clusterChanged:
        clusterChanged = False
        for i in range(m):
            # Find the center closest to sample i.
            minDist = inf
            minIndex = -1
            for j in range(k):
                distJI = dist(dataSet[i, :], center[j, :])
                if distJI < minDist:
                    minDist = distJI
                    minIndex = j
            # Any changed assignment means the algorithm has not converged.
            if clusterAssment[i, 0] != minIndex:
                clusterChanged = True
            clusterAssment[i, :] = minIndex, minDist ** 2
        # Single-argument call form works as both statement and function.
        print(center)
        for cent in range(k):
            # Rows of dataSet currently assigned to cluster ``cent``.
            members = dataSet[nonzero(clusterAssment[:, 0].A == cent)[0]]
            # An empty cluster keeps its previous center; the mean of zero
            # rows would otherwise be NaN.
            if shape(members)[0] > 0:
                # axis=0 averages down each column: one mean per feature.
                center[cent, :] = mean(members, axis=0)
    return center, clusterAssment
# Manual smoke test for kMeans; it is wrapped in a bare string literal and
# therefore never executed.
'''
#test
dataSet = mat(loadData("C:/Users/Administrator/Desktop/k-means/testSet.txt"))
k = 4
a = kMeans(dataSet,k)
print a
'''
matlab代码实现:
%%% K-means
clear all
clc
%% Generate four 3-D Gaussian clusters with 100 samples each
% All clusters use the same diagonal covariance; only the means differ.
% First cluster
mu1=[0 0 0];
S1=[0.23 0 0;0 0.87 0;0 0 0.56];
data1=mvnrnd(mu1,S1,100);
% Second cluster
mu2=[1.25 1.25 1.25];
S2=[0.23 0 0;0 0.87 0;0 0 0.56];
data2=mvnrnd(mu2,S2,100);
% Third cluster
mu3=[-1.25 1.25 -1.25];
S3=[0.23 0 0;0 0.87 0;0 0 0.56];
data3=mvnrnd(mu3,S3,100);
% Fourth cluster
mu4=[1.5 1.5 1.5];
S4=[0.23 0 0;0 0.87 0;0 0 0.56];
data4 =mvnrnd(mu4,S4,100);
%% Show the raw data, one color per generating cluster
figure;
plot3(data1(:,1),data1(:,2),data1(:,3),'+');
title('原始数据');
hold on
plot3(data2(:,1),data2(:,2),data2(:,3),'r+');
plot3(data3(:,1),data3(:,2),data3(:,3),'g+');
% All three coordinates must come from data4; taking z from data3 would
% draw points that do not exist in the data set.
plot3(data4(:,1),data4(:,2),data4(:,3),'y+');
grid on;
% Stack all samples into one matrix, one sample per row.
data=[data1;data2;data3;data4];
% row = number of samples, col = number of features per sample
[row,col] = size(data);
% Number of clusters
K = 4;
max_iter = 300;% upper bound of the iteration loop
min_impro = 0.1;% threshold compared against the center change when testing for convergence
display = 1;% flag checked inside the iteration loop
% center: centers used for the distance computation; U: newly computed centers
center = zeros(K,col);
U = zeros(K,col);
%% Initialize the cluster centers
% Each coordinate is drawn uniformly between the minimum (mi) and maximum (ma)
% of the corresponding data column.
mi = zeros(col,1);
ma = zeros(col,1);
for i = 1:col
mi(i,1) = min(data(:,i));
ma(i,1) = max(data(:,i));
center(:,i) = ma(i,1) - (ma(i,1) - mi(i,1)) * rand(K,1);
end
%% Iterate: assign every sample to its nearest center, then move the centers
for o = 1:max_iter
    %% Euclidean distance from every sample to every center
    D = zeros(row,K);
    for i = 1:K
        offset = data - repmat(center(i,:), row, 1);
        D(:,i) = sqrt(sum(offset.^2, 2));
    end
    % nearest(i): distance of sample i to its closest center,
    % assign(i): index of that center (first one wins on ties).
    [nearest, assign] = min(D, [], 2);
    % minDis keeps, per sample, the distance to its own center in the column
    % of that center and zero elsewhere.
    minDis = zeros(row,K);
    minDis(sub2ind([row K], (1:row)', assign)) = nearest;
    %% Update the cluster centers
    % Each center becomes the plain mean of its members. Weighting the
    % members by their distance would pull the center toward outliers and
    % ignore samples that sit exactly on the center.
    for i = 1:K
        members = (assign == i);
        if any(members)
            U(i,:) = mean(data(members,:), 1);
        else
            % Empty cluster: keep the previous center instead of 0/0 = NaN.
            U(i,:) = center(i,:);
        end
    end
    %% Convergence test
    if display
        % Placeholder: nothing is reported yet.
    end
    % Compare over all entries at once so the condition is a scalar.
    if max(abs(U(:) - center(:))) < min_impro
        break;
    else
        center = U;
    end
end
%% Label every sample with the index of its nearest center
% Each row of class is [sample coordinates, cluster index].
class = zeros(row, col + 1);
for i = 1:row
    dist = zeros(1, K);
    for j = 1:K
        dist(j) = norm(data(i,:) - U(j,:));
    end
    [nmin,index] = min(dist);
    class(i,:) = [data(i,:) index];
end
%% Show the final clustering
[m,n] = size(class);
figure;
title('聚类结果');
hold on;
% Column 4 of class holds the cluster index; pick one marker style per cluster.
for i=1:row
    switch class(i,4)
        case 1
            marker = 'ro';
        case 2
            marker = 'go';
        case 3
            marker = 'bo';
        otherwise
            marker = 'yo';
    end
    plot3(class(i,1),class(i,2),class(i,3),marker);
end
grid on;