一、 Mean Shift 算法
K-Means 算法最終的聚類效果受初始的聚類中心的影響,K-Means++ 算法未選擇較好的初始聚類中心提供了依據,但在 K-Means 算法中,聚類的類別個數 k 仍需要事先指定。對於類別個數未知的, K-Means 算法和 K-Means++ 算法很難將其進行精確求解。 Mean Shift 算法被提出用於解決聚類個數未知的情況。
Mean Shift 算法又稱均值漂移算法,是基於聚類中心的聚類算法。實現不需要指定類別個數k,聚類中心是通過在給定區域中的均值來確定的,通過不斷更新聚類中心,直到最終的聚類中心不再改變。 Mean Shift 算法在聚類、圖像平滑、分割和視頻跟蹤等方面有廣泛的應用。
二、 Mean Shift 算法的原理
1.核函數
Mean Shift算法中引入核函數的目的是使得隨着樣本與被偏移點的距離不同,其偏移量對均值偏移向量的貢獻也不同
核函數的定義
X表示一個d維的歐式空間,x是該空間中的一個點x={x1,x2,x3⋯,xd},其中,x的模,R表示實數域,如果一個函數K:X→R存在一個剖面函數,即
並且滿足:
- k是非負的
- k是非增的
- k是分段連續的
那麼,函數K(x)就稱爲核函數。
常用的核函數
線性核:
多項式核: ,爲多項式次數
高斯核: ,爲高斯核的帶寬
拉普拉斯核:
Sigmoid核: ,tanh爲雙曲正切函數,
在這裏我們使用高斯核函數,將他寫成下面的形式:
h爲帶寬,不同帶寬的核函數如下所示:
import matplotlib.pyplot as plt
import math
def cal_Gaussian(x, h=1):
molecule = x * x
denominator = 2 * h * h
left = 1 / (math.sqrt(2 * math.pi) * h)
return left * math.exp(-molecule / denominator)
x = []
for i in range(-20,20):
x.append(i * 0.5);
score_1 = []
score_2 = []
score_3 = []
for i in x:
score_1.append(cal_Gaussian(i,1))
score_2.append(cal_Gaussian(i,2))
score_3.append(cal_Gaussian(i,3))
plt.figure(figsize=(10,8), dpi=80)
plt.plot(x, score_1, 'r--', label="h=1")
plt.plot(x, score_2, 'b--', label="h=2")
plt.plot(x, score_3, 'g--', label="h=3")
#顯示中文標題
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.legend(loc="upper right")
plt.title("高斯核函數")
plt.xlabel("x")
plt.ylabel("K")
plt.show()
從圖中可以看出,當 h 一定是,樣本點之間的距離越近,其核函數的值越大;當樣本點之間的距離相等時,隨着高斯核函數的帶寬 h 的增大,核函數的值在減少。
2.基本原理
基本的Mean Shift向量
對於給定的d維空間Rd中的n個樣本點xi,i=1,⋯,n,其Mean Shift向量的基本形式爲:
其中,Sh指的是一個半徑爲h的高維球區域,Sh的定義爲:
這樣的Mean Shift形式存在一個問題:在Sh的區域內,每一個點對x的貢獻是一樣的。而實際上,每一個樣本點對x的貢獻是不一樣的。
改進的Mean Shift向量
爲使每一個樣本點對x的貢獻不一樣,基本的Mean Shift向量形式中增加核函數,得如下改進的Mean Shift向量形式:
其中爲高斯核函數,可以取Sh爲整個數據集範圍,Mean Shift向量Mh(x)是歸一化的概率密度梯度。
Mean Shift 算法的基本過程
聚類中心是通過在給定區域中的均值來確定的,通過不斷更新聚類中心,直到最終的聚類中心不再改變退出。
1.在指定的區域內計算偏移均值(如下圖的黃色的圈),並移動該點到偏移均值點處
2.重複上述的過程計算新的偏移均值,並移動到偏移均值點處
3.直到最終的聚類中心不再改變退出
Mean Shift算法的解釋
在Mean Shift算法中,實際上是利用了概率密度,求得概率密度的局部最優解。
對一個概率密度函數f(x),已知d維空間中n個採樣點xi,i=1,⋯,n,f(x)的核函數估計(也稱爲Parzen窗估計)爲:
Mean Shift向量的修正:
3.算法流程
- 計算
- 令
- 如果,結束循環,否則,重複上述步驟。
三、Mean Shift 算法實踐
# -*- coding: utf-8 -*-
"""
Created on Tue Apr 2 17:16:08 2019
@author: 2018061801
"""
import matplotlib.pyplot as plt
import math
import numpy as np
MIN_DISTANCE = 0.000001 # mini error
def load_data(path, feature_num=2):
'''導入數據
input: path(string)文件的存儲位置
feature_num(int)特徵的個數
output: data(array)特徵
'''
f = open(path) # 打開文件
data = []
for line in f.readlines():
lines = line.strip().split("\t")
data_tmp = []
if len(lines) != feature_num: # 判斷特徵的個數是否正確
continue
for i in range(feature_num):
data_tmp.append(float(lines[i]))
data.append(data_tmp)
f.close() # 關閉文件
return data
def gaussian_kernel(distance, bandwidth):
'''高斯核函數
input: distance(mat):歐式距離
bandwidth(int):核函數的帶寬
output: gaussian_val(mat):高斯函數值
'''
m = np.shape(distance)[0] # 樣本個數
right = np.mat(np.zeros((m, 1))) # mX1的矩陣
for i in range(m):
right[i, 0] = (-0.5 * distance[i] * distance[i].T) / (bandwidth * bandwidth)
right[i, 0] = np.exp(right[i, 0])
left = 1 / (bandwidth * math.sqrt(2 * math.pi))
gaussian_val = left * right
return gaussian_val
def shift_point(point, points, kernel_bandwidth):
'''計算均值漂移點
input: point(mat)需要計算的點
points(array)所有的樣本點
kernel_bandwidth(int)核函數的帶寬
output: point_shifted(mat)漂移後的點
'''
points = np.mat(points)
m = np.shape(points)[0] # 樣本的個數
# 計算距離
point_distances = np.mat(np.zeros((m, 1)))
for i in range(m):
point_distances[i, 0] = euclidean_dist(point, points[i])
# 計算高斯核
point_weights = gaussian_kernel(point_distances, kernel_bandwidth) # mX1的矩陣
# 計算分母
all_sum = 0.0
for i in range(m):
all_sum += point_weights[i, 0]
# 均值偏移
point_shifted = point_weights.T * points / all_sum
return point_shifted
def euclidean_dist(pointA, pointB):
'''計算歐式距離
input: pointA(mat):A點的座標
pointB(mat):B點的座標
output: math.sqrt(total):兩點之間的歐式距離
'''
# 計算pointA和pointB之間的歐式距離
total = (pointA - pointB) * (pointA - pointB).T
return math.sqrt(total) # 歐式距離
def group_points(mean_shift_points):
'''計算所屬的類別
input: mean_shift_points(mat):漂移向量
output: group_assignment(array):所屬類別
'''
group_assignment = []
m, n = np.shape(mean_shift_points)
index = 0
index_dict = {}
for i in range(m):
item = []
for j in range(n):
item.append(str(("%5.2f" % mean_shift_points[i, j])))
item_1 = "_".join(item)
if item_1 not in index_dict:
index_dict[item_1] = index
index += 1
for i in range(m):
item = []
for j in range(n):
item.append(str(("%5.2f" % mean_shift_points[i, j])))
item_1 = "_".join(item)
group_assignment.append(index_dict[item_1])
return group_assignment
def train_mean_shift(points, kenel_bandwidth=2):
'''訓練Mean shift模型
input: points(array):特徵數據
kenel_bandwidth(int):核函數的帶寬
output: points(mat):特徵點
mean_shift_points(mat):均值漂移點
group(array):類別
'''
mean_shift_points = np.mat(points)
max_min_dist = 1
iteration = 0 # 訓練的代數
m = np.shape(mean_shift_points)[0] # 樣本的個數
need_shift = [True] * m # 標記是否需要漂移
# 計算均值漂移向量
while max_min_dist > MIN_DISTANCE:
max_min_dist = 0
iteration += 1
print ("\titeration : " + str(iteration))
for i in range(0, m):
# 判斷每一個樣本點是否需要計算偏移均值
if not need_shift[i]:
continue
p_new = mean_shift_points[i]
p_new_start = p_new
p_new = shift_point(p_new, points, kenel_bandwidth) # 對樣本點進行漂移
dist = euclidean_dist(p_new, p_new_start) # 計算該點與漂移後的點之間的距離
if dist > max_min_dist:
max_min_dist = dist
if dist < MIN_DISTANCE: # 不需要移動
need_shift[i] = False
mean_shift_points[i] = p_new
# 計算最終的group
group = group_points(mean_shift_points) # 計算所屬的類別
return np.mat(points), mean_shift_points, group
def save_result(file_name, data):
'''保存最終的計算結果
input: file_name(string):存儲的文件名
data(mat):需要保存的文件
'''
f = open(file_name, "w")
m, n = np.shape(data)
for i in range(m):
tmp = []
for j in range(n):
tmp.append(str(data[i, j]))
f.write("\t".join(tmp) + "\n")
f.close()
if __name__ == "__main__":
# 導入數據集
print ("----------1.load data ------------")
data = load_data("D:/anaconda4.3/spyder_work/data5.txt", 2)
# 訓練,h=2
print ("----------2.training ------------")
points, shift_points, cluster = train_mean_shift(data, 2)
# 保存所屬的類別文件
print ("----------3.1.save sub ------------")
save_result("sub_1", np.mat(cluster))
print ("----------3.2.save center ------------")
# 保存聚類中心
save_result("center", shift_points)
f = open("D:/anaconda4.3/spyder_work/data5.txt")
x = []
y = []
for line in f.readlines():
lines = line.strip().split("\t")
if len(lines) == 2:
x.append(float(lines[0]))
y.append(float(lines[1]))
f.close()
#顯示中文標題
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.figure(figsize=(10,8), dpi=80)
plt.plot(x, y, 'b.', label="原始數據")
plt.title('未使用聚類算法')
plt.legend(loc="upper right")
plt.show()
cluster_x_0 = []
cluster_x_1 = []
cluster_x_2 = []
cluster_y_0 = []
cluster_y_1 = []
cluster_y_2 = []
N = len(data)
data = np.array(data)
f = open("D:/anaconda4.3/spyder_work/center.txt")
center_x = []
center_y = []
for line in f.readlines():
lines = line.strip().split("\t")
if len(lines) == 2:
center_x.append(lines[0])
center_y.append(lines[1])
f.close()
for i in range(N):
if cluster[i]==0:
cluster_x_0.append(data[i, 0])
cluster_y_0.append(data[i, 1])
elif cluster[i]==1:
cluster_x_1.append(data[i, 0])
cluster_y_1.append(data[i, 1])
elif cluster[i]==2:
cluster_x_2.append(data[i, 0])
cluster_y_2.append(data[i, 1])
plt.figure(figsize=(10,8), dpi=80)
plt.plot(cluster_x_0, cluster_y_0,'y.',label="cluster_0")
plt.plot(cluster_x_1, cluster_y_1,'g.',label="cluster_1")
plt.plot(cluster_x_2, cluster_y_2,'b.',label="cluster_2")
plt.plot(center_x, center_y, '+m', label="mean point")
plt.title('使用聚類算法')
plt.legend(loc="best")
plt.show()
結果:
----------1.load data ------------
----------2.training ------------
iteration : 1
iteration : 2
iteration : 3
iteration : 4
iteration : 5
iteration : 6
iteration : 7
iteration : 8
iteration : 9
iteration : 10
iteration : 11
iteration : 12
iteration : 13
iteration : 14
iteration : 15
iteration : 16
iteration : 17
iteration : 18
iteration : 19
iteration : 20
iteration : 21
iteration : 22
iteration : 23
iteration : 24
iteration : 25
iteration : 26
iteration : 27
iteration : 28
----------3.1.save sub ------------
----------3.2.save center ------------
參考文獻:
3.周志華——機器學習
4.趙志勇——Python 機器學習算法