機器學習(十四)——K-Means聚類

理解聚類:
在這裏插入圖片描述
本質:
在這裏插入圖片描述

1. K-Means做聚類

理解k-Means過程

例:對下圖做聚類
在這裏插入圖片描述
先隨機生成k箇中心點,如下圖,k=3
在這裏插入圖片描述
然後給定一個初始化分
在這裏插入圖片描述
計算每個cluster的x,y的均值,使得每個cluster都產生一個均值中心點。
然後根據距離調整劃分情況,如下圖,上圖中的三角形離中心五角星更近,所以把它劃分成五角星,
在這裏插入圖片描述
新加入一個五角星後,中心自然會再次向新加入的方向移動一點,中心移動後,可能又會加入新的,因爲可能又會有三角形離五角星的中心比三角形的中心更近,如下圖,又有兩個
在這裏插入圖片描述
新的兩個三角形加入後,中心又會調整…一直迭代,直至中心點不再變化或者沒有新的點加入,
在這裏插入圖片描述

上述例子的流程就是K-Means

流程

在這裏插入圖片描述
上圖中的第二個公式我們也可以看的出來,求加和再除以個數,就是求中心點的;第一個公式是算點與每個中心點的最小距離,所以它是用來分類的。

缺點

K-Means開始會隨機生成k箇中心點,完一隨機生成的中心點距離特別近

2. 其他做聚類的方法

在這裏插入圖片描述
在這裏插入圖片描述
在這裏插入圖片描述
例:第一步:隨機第一個中心點k,
第二步:算每個樣本到k的距離,假設又A,B,C三個樣本到k的距離分別爲5,4,1;把距離算出概率,把概率作爲座標軸,則0~50是A的50%的概率,50~90是B的40%,90~100是C的10%
第三步:隨機一個0~100的數,如果這個數在0~50之間,則選A作爲另一箇中心點,如果在50~90之間,則中心點爲B,如果在90~100之間,則中心點爲C

解決了k-means的k箇中心點隨機情況,k-means++使得初始的k箇中心點分佈均勻些

3. 選擇適當的聚類數 k

在這裏插入圖片描述
上圖中的橫座標是k,縱座標是Error
Error=每一個類的MSE加和,每一個類的MSE=每個樣本和中心點的距離平方加和

如上圖箭頭所指的地方,也就是拐點處就是我們要的k最適當的地方。
原因:

距離A=上一次的Error(縱座標) - 這一次的Error(縱座標)
距離B= 這一次的 - 下一次的,
而拐點處就是距離A-距離B最大的地方
說白了,就是找到一個點,這個點之後不管k怎麼增加,對Error的影響都不是那麼明顯

在這裏插入圖片描述
在這裏插入圖片描述
在這裏插入圖片描述

4. 代碼實現

# -*- coding:utf-8 -*-

import numpy as np
import matplotlib.pyplot as plt
import sklearn.datasets as ds
import matplotlib.colors
from matplotlib import font_manager
from sklearn.cluster import KMeans
from sklearn.cluster import MiniBatchKMeans


def expand(a, b):
    d = (b - a) * 0.1
    return a-d, b+d


if __name__ == "__main__":
    N = 400
    centers = 4
    # 創建聚類的模擬數據400條,兩個特徵,4個k,有y的原因是對比我們評估的效果
    data, y = ds.make_blobs(N, n_features=2, centers=centers, random_state=2)
    # 與上面不同的是方差,方差越大,數據點越分散,方差越小,數據越密集,四個類的數據密集程度不一樣
    data2, y2 = ds.make_blobs(N, n_features=2, centers=centers, cluster_std=(1, 2.5, 0.5, 2), random_state=2)
    # 也是生成四個類別的數據,第一個類別取data第一個類別的全部,第二個類別取data第二個類別前50條數據,第三個類別取data第三個類別前20條,第四個類別取data第四個類別前5條
    data3 = np.vstack((data[y == 0][:], data[y == 1][:50], data[y == 2][:20], data[y == 3][:5]))
    y3 = np.array([0] * 100 + [1] * 50 + [2] * 20 + [3] * 5)

    cls = KMeans(n_clusters=4, init='k-means++')
    # 聚類並且打標籤
    y_hat = cls.fit_predict(data)
    y2_hat = cls.fit_predict(data2)
    y3_hat = cls.fit_predict(data3)

    m = np.array(((1, 1), (1, 3)))
    data_r = data.dot(m)
    y_r_hat = cls.fit_predict(data_r)

    myfont = font_manager.FontProperties(fname="/usr/share/fonts/cjkuni-uming/uming.ttc", size=18)

    matplotlib.rcParams['font.sans-serif'] = ['SimHei']
    matplotlib.rcParams['axes.unicode_minus'] = False
    # 四個類別數據點的顏色
    cm = matplotlib.colors.ListedColormap(list('rgbm'))

    plt.figure(figsize=(9, 10), facecolor='w')
    plt.subplot(421)
    plt.title("原始數據",fontproperties=myfont)
    # 畫三點圖
    plt.scatter(data[:, 0], data[:, 1], c=y, s=30, cmap=cm, edgecolors='none')
    # 根據數據找到座標軸的最大數和最小數
    x1_min, x2_min = np.min(data, axis=0)
    x1_max, x2_max = np.max(data, axis=0)
    x1_min, x1_max = expand(x1_min, x1_max)
    x2_min, x2_max = expand(x2_min, x2_max)
    # 畫座標軸
    plt.xlim((x1_min, x1_max))
    plt.ylim((x2_min, x2_max))
    plt.grid(True)

    plt.subplot(422)
    plt.title('KMeans++聚類',fontproperties=myfont)
    plt.scatter(data[:, 0], data[:, 1], c=y_hat, s=30, cmap=cm, edgecolors='none')
    plt.xlim((x1_min, x1_max))
    plt.ylim((x2_min, x2_max))
    plt.grid(True)

    plt.subplot(423)
    plt.title('旋轉後數據', fontproperties=myfont)
    plt.scatter(data_r[:, 0], data_r[:, 1], c=y, s=30, cmap=cm, edgecolors='none')
    x1_min, x2_min = np.min(data_r, axis=0)
    x1_max, x2_max = np.max(data_r, axis=0)
    x1_min, x1_max = expand(x1_min, x1_max)
    x2_min, x2_max = expand(x2_min, x2_max)
    plt.xlim((x1_min, x1_max))
    plt.ylim((x2_min, x2_max))
    plt.grid(True)

    plt.subplot(424)
    plt.title('旋轉後KMeans++聚類',fontproperties=myfont)
    plt.scatter(data_r[:, 0], data_r[:, 1], c=y_r_hat, s=30, cmap=cm, edgecolors='none')
    plt.xlim((x1_min, x1_max))
    plt.ylim((x2_min, x2_max))
    plt.grid(True)

    plt.subplot(425)
    plt.title('方差不相等數據', fontproperties=myfont)
    plt.scatter(data2[:, 0], data2[:, 1], c=y2, s=30, cmap=cm, edgecolors='none')
    x1_min, x2_min = np.min(data2, axis=0)
    x1_max, x2_max = np.max(data2, axis=0)
    x1_min, x1_max = expand(x1_min, x1_max)
    x2_min, x2_max = expand(x2_min, x2_max)
    plt.xlim((x1_min, x1_max))
    plt.ylim((x2_min, x2_max))
    plt.grid(True)

    plt.subplot(426)
    plt.title('方差不相等KMeans++聚類', fontproperties=myfont)
    plt.scatter(data2[:, 0], data2[:, 1], c=y2_hat, s=30, cmap=cm, edgecolors='none')
    plt.xlim((x1_min, x1_max))
    plt.ylim((x2_min, x2_max))
    plt.grid(True)

    plt.subplot(427)
    plt.title('數量不相等數據', fontproperties=myfont)
    plt.scatter(data3[:, 0], data3[:, 1], s=30, c=y3, cmap=cm, edgecolors='none')
    x1_min, x2_min = np.min(data3, axis=0)
    x1_max, x2_max = np.max(data3, axis=0)
    x1_min, x1_max = expand(x1_min, x1_max)
    x2_min, x2_max = expand(x2_min, x2_max)
    plt.xlim((x1_min, x1_max))
    plt.ylim((x2_min, x2_max))
    plt.grid(True)

    plt.subplot(428)
    plt.title('數量不相等KMeans++聚類', fontproperties=myfont)
    plt.scatter(data3[:, 0], data3[:, 1], c=y3_hat, s=30, cmap=cm, edgecolors='none')
    plt.xlim((x1_min, x1_max))
    plt.ylim((x2_min, x2_max))
    plt.grid(True)

    plt.tight_layout(2, rect=(0, 0, 1, 0.97))
    plt.suptitle('數據分佈對KMeans聚類的影響', fontproperties=myfont)
    # https://github.com/matplotlib/matplotlib/issues/829
    # plt.subplots_adjust(top=0.92)

    plt.savefig('cluster_kmeans')
    plt.show()

在這裏插入圖片描述

應用案例:降維

對圖片進行降維


# -*- coding: utf-8 -*-

from PIL import Image
import numpy as np
from sklearn.cluster import KMeans
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D


def restore_image(cb, cluster, shape):
    row, col, dummy = shape
    image = np.empty((row, col, 3))
    index = 0
    for r in range(row):
        for c in range(col):
            image[r, c] = cb[cluster[index]]
            index += 1
    return image


def show_scatter(a):
    N = 10
    print('原始數據:\n', a)
    density, edges = np.histogramdd(a, bins=[N,N,N], range=[(0,1), (0,1), (0,1)])
    density /= density.max()
    x = y = z = np.arange(N)
    d = np.meshgrid(x, y, z)

    fig = plt.figure(1, facecolor='w')
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(d[1], d[0], d[2], c='r', s=100*density, marker='o', depthshade=True)
    ax.set_xlabel(u'紅色分量')
    ax.set_ylabel(u'綠色分量')
    ax.set_zlabel(u'藍色分量')
    plt.title(u'圖像顏色三維頻數分佈', fontsize=20)

    plt.figure(2, facecolor='w')
    den = density[density > 0]
    den = np.sort(den)[::-1]
    t = np.arange(len(den))
    plt.plot(t, den, 'r-', t, den, 'go', lw=2)
    plt.title(u'圖像顏色頻數分佈', fontsize=18)
    plt.grid(True)

    plt.show()


if __name__ == '__main__':
    # 中文有關
    matplotlib.rcParams['font.sans-serif'] = [u'SimHei']
    matplotlib.rcParams['axes.unicode_minus'] = False

    num_vq = 256
    im = Image.open('./Lena.png')     # flower2.png(200)/lena.png(50)
    image = np.array(im).astype(np.float) / 255
    image = image[:, :, :3]  # 只要rgb三個維度,不要透明度
    image_v = image.reshape((-1, 3))  # 拉伸
    model = KMeans(num_vq)  # 256個類別
    show_scatter(image_v)

    N = image_v.shape[0]    # 圖像像素總數
    # 選擇足夠多的樣本(如1000個),計算聚類中心
    # 隨機選擇1000個像素點
    idx = np.random.randint(0, N, size=1000)
    image_sample = image_v[idx]
    model.fit(image_sample)
    c = model.predict(image_v)  # 聚類結果
    print('聚類結果:\n', c)
    print('聚類中心:\n', model.cluster_centers_)

    plt.figure(figsize=(15, 8), facecolor='w')
    plt.subplot(121)
    plt.axis('off')
    plt.title(u'原始圖片', fontsize=18)
    plt.imshow(image)
    # plt.savefig('1.png')

    plt.subplot(122)
    vq_image = restore_image(model.cluster_centers_, c, image.shape)
    plt.axis('off')
    plt.title(u'矢量量化後圖片:%d色' % num_vq, fontsize=20)
    plt.imshow(vq_image)
    # plt.savefig('2.png')

    plt.tight_layout(1.2)
    plt.show()
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章