Mean Shift 算法原理及 Python 實現

 一、 Mean Shift 算法

 K-Means 算法最終的聚類效果受初始的聚類中心的影響,K-Means++ 算法未選擇較好的初始聚類中心提供了依據,但在 K-Means 算法中,聚類的類別個數 k 仍需要事先指定。對於類別個數未知的, K-Means 算法和 K-Means++ 算法很難將其進行精確求解。 Mean Shift 算法被提出用於解決聚類個數未知的情況。

 Mean Shift 算法又稱均值漂移算法,是基於聚類中心的聚類算法。實現不需要指定類別個數k,聚類中心是通過在給定區域中的均值來確定的,通過不斷更新聚類中心,直到最終的聚類中心不再改變。 Mean Shift 算法在聚類、圖像平滑、分割和視頻跟蹤等方面有廣泛的應用。

二、 Mean Shift 算法的原理

1.核函數

Mean Shift算法中引入核函數的目的是使得隨着樣本與被偏移點的距離不同,其偏移量對均值偏移向量的貢獻也不同

核函數的定義

X表示一個d維的歐式空間,x是該空間中的一個點x={x1,x2,x3⋯,xd},其中,x的模\left \| x\right \|^{2}=xx^{T},R表示實數域,如果一個函數K:X→R存在一個剖面函數k:\left [ 0,\infty\right ]\rightarrow R,即
                                                                            K\left ( x\right )=k(\left \| x\right \|^{2})
  並且滿足: 

  • k是非負的 
  • k是非增的 
  • k是分段連續的 

  那麼,函數K(x)就稱爲核函數。

常用的核函數

線性核:k(x,y)=x^{T}y

多項式核:k(x,y)=(x^{T}y)^{d}  ,d\geq 1爲多項式次數

高斯核:k(x,y)=exp(-\frac{\left \| x-y\right|^{2}}{2\sigma ^{2}})   ,\sigma > 0爲高斯核的帶寬

拉普拉斯核:k(x,y)=exp(-\frac{\left \| x-y\right|}{\sigma })   

Sigmoid核:k(x,y)=tanh(\beta x^{T}y+\theta )      ,tanh爲雙曲正切函數,\beta > 0,\theta < 0

在這裏我們使用高斯核函數,將他寫成下面的形式:

K(\frac{x_{1}-x_{2}}{h})=\frac{1}{\sqrt{2\pi}h}exp(-\frac{(x_{1}-x_{2})^{2}}{2h^{2}})

h爲帶寬,不同帶寬的核函數如下所示:

import matplotlib.pyplot as plt
import math

def cal_Gaussian(x, h=1):
    molecule = x * x
    denominator = 2 * h * h
    left = 1 / (math.sqrt(2 * math.pi) * h)
    return left * math.exp(-molecule / denominator)

x = []

for i in range(-20,20):
    x.append(i * 0.5);

score_1 = []
score_2 = []
score_3 = []

for i in x:
    score_1.append(cal_Gaussian(i,1))
    score_2.append(cal_Gaussian(i,2))
    score_3.append(cal_Gaussian(i,3))
 
plt.figure(figsize=(10,8), dpi=80)    
plt.plot(x, score_1, 'r--', label="h=1")
plt.plot(x, score_2, 'b--', label="h=2")
plt.plot(x, score_3, 'g--', label="h=3")

#顯示中文標題
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False

plt.legend(loc="upper right")
plt.title("高斯核函數")
plt.xlabel("x")
plt.ylabel("K")
plt.show()

從圖中可以看出,當 h 一定是,樣本點之間的距離越近,其核函數的值越大;當樣本點之間的距離相等時,隨着高斯核函數的帶寬 h 的增大,核函數的值在減少。

2.基本原理

基本的Mean Shift向量

對於給定的d維空間Rd中的n個樣本點xi,i=1,⋯,n,其Mean Shift向量的基本形式爲:

其中,Sh指的是一個半徑爲h的高維球區域,Sh的定義爲:

這樣的Mean Shift形式存在一個問題:在Sh的區域內,每一個點對x的貢獻是一樣的。而實際上,每一個樣本點對x的貢獻是不一樣的。

改進的Mean Shift向量

爲使每一個樣本點對x的貢獻不一樣,基本的Mean Shift向量形式中增加核函數,得如下改進的Mean Shift向量形式:

M_{h}(X)=\frac{\sum_{X^{i}\in S_{h}}[K(\frac{X^{i}-X}{h})\cdot (X^{i}-X)]}{\sum_{X^{i}\in S_{h}}[K(\frac{X^{i}-X}{h})]}

其中K(\frac{X^{i}-X}{h})爲高斯核函數,可以取Sh爲整個數據集範圍,Mean Shift向量Mh(x)是歸一化的概率密度梯度。

Mean Shift 算法的基本過程

聚類中心是通過在給定區域中的均值來確定的,通過不斷更新聚類中心,直到最終的聚類中心不再改變退出。

1.在指定的區域內計算偏移均值(如下圖的黃色的圈),並移動該點到偏移均值點處

2.重複上述的過程計算新的偏移均值,並移動到偏移均值點處

 

3.直到最終的聚類中心不再改變退出

Mean Shift算法的解釋

在Mean Shift算法中,實際上是利用了概率密度,求得概率密度的局部最優解。

對一個概率密度函數f(x),已知d維空間中n個採樣點xi,i=1,⋯,n,f(x)的核函數估計(也稱爲Parzen窗估計)爲:

Mean Shift向量的修正:

3.算法流程

  • 計算m_{h}(X)
  • X=m_{h}\left ( X \right )
  • 如果\left \| m_{h}\left ( X \right ) -X\right \|< \varepsilon,結束循環,否則,重複上述步驟。

 三、Mean Shift 算法實踐

# -*- coding: utf-8 -*-
"""
Created on Tue Apr  2 17:16:08 2019

@author: 2018061801
"""
import matplotlib.pyplot as plt

import math
import numpy as np

MIN_DISTANCE = 0.000001  # mini error

def load_data(path, feature_num=2):
    '''導入數據
    input:  path(string)文件的存儲位置
            feature_num(int)特徵的個數
    output: data(array)特徵
    '''
    f = open(path)  # 打開文件
    data = []
    for line in f.readlines():
        lines = line.strip().split("\t")
        data_tmp = []
        if len(lines) != feature_num:  # 判斷特徵的個數是否正確
            continue
        for i in range(feature_num):
            data_tmp.append(float(lines[i]))
        data.append(data_tmp)
    f.close()  # 關閉文件
    return data

def gaussian_kernel(distance, bandwidth):
    '''高斯核函數
    input:  distance(mat):歐式距離
            bandwidth(int):核函數的帶寬
    output: gaussian_val(mat):高斯函數值
    '''
    m = np.shape(distance)[0]  # 樣本個數
    right = np.mat(np.zeros((m, 1)))  # mX1的矩陣
    for i in range(m):
        right[i, 0] = (-0.5 * distance[i] * distance[i].T) / (bandwidth * bandwidth)
        right[i, 0] = np.exp(right[i, 0])
    left = 1 / (bandwidth * math.sqrt(2 * math.pi))
    
    gaussian_val = left * right
    return gaussian_val

def shift_point(point, points, kernel_bandwidth):
    '''計算均值漂移點
    input:  point(mat)需要計算的點
            points(array)所有的樣本點
            kernel_bandwidth(int)核函數的帶寬
    output: point_shifted(mat)漂移後的點
    '''
    points = np.mat(points)
    m = np.shape(points)[0]  # 樣本的個數
    # 計算距離
    point_distances = np.mat(np.zeros((m, 1)))
    for i in range(m):
        point_distances[i, 0] = euclidean_dist(point, points[i])
    
    # 計算高斯核        
    point_weights = gaussian_kernel(point_distances, kernel_bandwidth)  # mX1的矩陣
    
    # 計算分母
    all_sum = 0.0
    for i in range(m):
        all_sum += point_weights[i, 0]
    
    # 均值偏移
    point_shifted = point_weights.T * points / all_sum
    return point_shifted

def euclidean_dist(pointA, pointB):
    '''計算歐式距離
    input:  pointA(mat):A點的座標
            pointB(mat):B點的座標
    output: math.sqrt(total):兩點之間的歐式距離
    '''
    # 計算pointA和pointB之間的歐式距離
    total = (pointA - pointB) * (pointA - pointB).T
    return math.sqrt(total)  # 歐式距離

def group_points(mean_shift_points):
    '''計算所屬的類別
    input:  mean_shift_points(mat):漂移向量
    output: group_assignment(array):所屬類別
    '''
    group_assignment = []
    m, n = np.shape(mean_shift_points)
    index = 0
    index_dict = {}
    for i in range(m):
        item = []
        for j in range(n):
            item.append(str(("%5.2f" % mean_shift_points[i, j])))
        
        item_1 = "_".join(item)
        if item_1 not in index_dict:
            index_dict[item_1] = index
            index += 1
    
    for i in range(m):
        item = []
        for j in range(n):
            item.append(str(("%5.2f" % mean_shift_points[i, j])))

        item_1 = "_".join(item)
        group_assignment.append(index_dict[item_1])

    return group_assignment

def train_mean_shift(points, kenel_bandwidth=2):
    '''訓練Mean shift模型
    input:  points(array):特徵數據
            kenel_bandwidth(int):核函數的帶寬
    output: points(mat):特徵點
            mean_shift_points(mat):均值漂移點
            group(array):類別
    '''
    mean_shift_points = np.mat(points)
    max_min_dist = 1
    iteration = 0  # 訓練的代數
    m = np.shape(mean_shift_points)[0]  # 樣本的個數
    need_shift = [True] * m  # 標記是否需要漂移

    # 計算均值漂移向量
    while max_min_dist > MIN_DISTANCE:
        max_min_dist = 0
        iteration += 1
        print ("\titeration : " + str(iteration))
        for i in range(0, m):
            # 判斷每一個樣本點是否需要計算偏移均值
            if not need_shift[i]:
                continue
            p_new = mean_shift_points[i]
            p_new_start = p_new
            p_new = shift_point(p_new, points, kenel_bandwidth)  # 對樣本點進行漂移
            dist = euclidean_dist(p_new, p_new_start)  # 計算該點與漂移後的點之間的距離

            if dist > max_min_dist:
                max_min_dist = dist
            if dist < MIN_DISTANCE:  # 不需要移動
                need_shift[i] = False

            mean_shift_points[i] = p_new

    # 計算最終的group
    group = group_points(mean_shift_points)  # 計算所屬的類別
    
    return np.mat(points), mean_shift_points, group

def save_result(file_name, data):
    '''保存最終的計算結果
    input:  file_name(string):存儲的文件名
            data(mat):需要保存的文件
    '''
    f = open(file_name, "w")
    m, n = np.shape(data)
    for i in range(m):
        tmp = []
        for j in range(n):
            tmp.append(str(data[i, j]))
        f.write("\t".join(tmp) + "\n")
    f.close()
    

if __name__ == "__main__":
    # 導入數據集
    print ("----------1.load data ------------")
    data = load_data("D:/anaconda4.3/spyder_work/data5.txt", 2)
    # 訓練,h=2
    print ("----------2.training ------------")
    points, shift_points, cluster = train_mean_shift(data, 2)
    # 保存所屬的類別文件
    print ("----------3.1.save sub ------------")
    save_result("sub_1", np.mat(cluster))
    print ("----------3.2.save center ------------")
    # 保存聚類中心
    save_result("center", shift_points)    



f = open("D:/anaconda4.3/spyder_work/data5.txt")
x = []
y = []
for line in f.readlines():
    lines = line.strip().split("\t")
    if len(lines) == 2:
        x.append(float(lines[0]))
        y.append(float(lines[1]))
f.close()  

#顯示中文標題
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False

plt.figure(figsize=(10,8), dpi=80) 
plt.plot(x, y, 'b.', label="原始數據")
plt.title('未使用聚類算法')
plt.legend(loc="upper right")
plt.show()

cluster_x_0 = []
cluster_x_1 = []
cluster_x_2 = []
cluster_y_0 = []
cluster_y_1 = []
cluster_y_2 = []
N = len(data)
data = np.array(data)

f = open("D:/anaconda4.3/spyder_work/center.txt") 
center_x = []
center_y = []
for line in f.readlines():
    lines = line.strip().split("\t")
    if len(lines) == 2:
        center_x.append(lines[0])
        center_y.append(lines[1])
f.close() 
for i in range(N):
    if cluster[i]==0:
        cluster_x_0.append(data[i, 0])
        cluster_y_0.append(data[i, 1])
    elif cluster[i]==1:
        cluster_x_1.append(data[i, 0])
        cluster_y_1.append(data[i, 1])
    elif cluster[i]==2:
        cluster_x_2.append(data[i, 0])
        cluster_y_2.append(data[i, 1])
      
plt.figure(figsize=(10,8), dpi=80)
plt.plot(cluster_x_0, cluster_y_0,'y.',label="cluster_0")
plt.plot(cluster_x_1, cluster_y_1,'g.',label="cluster_1")
plt.plot(cluster_x_2, cluster_y_2,'b.',label="cluster_2")
plt.plot(center_x, center_y, '+m', label="mean point")        
plt.title('使用聚類算法')
plt.legend(loc="best")           
plt.show() 



結果:

----------1.load data ------------
----------2.training ------------
        iteration : 1
        iteration : 2
        iteration : 3
        iteration : 4
        iteration : 5
        iteration : 6
        iteration : 7
        iteration : 8
        iteration : 9
        iteration : 10
        iteration : 11
        iteration : 12
        iteration : 13
        iteration : 14
        iteration : 15
        iteration : 16
        iteration : 17
        iteration : 18
        iteration : 19
        iteration : 20
        iteration : 21
        iteration : 22
        iteration : 23
        iteration : 24
        iteration : 25
        iteration : 26
        iteration : 27
        iteration : 28
----------3.1.save sub ------------
----------3.2.save center ------------

參考文獻:

1.簡單易學的機器學習算法——Mean Shift聚類算法

2.meanshift算法簡介

3.周志華——機器學習

4.趙志勇——Python 機器學習算法

 


 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章