(十四)sklearn 均值漂移聚類

 代碼大部分來自官方文檔,可以直接運行

import numpy as np

# estimate_bandwidth用於設置帶寬
from sklearn.cluster import MeanShift, estimate_bandwidth

# 生成測試數據所需要的庫
from sklearn.datasets.samples_generator import make_blobs

# 以(1,1),(-1,-1),(1,-1)爲中心生成10000個標準差爲0.6的測試數據集
centers = np.array([[1, 1], [-1, -1], [1, -1]])
X, _ = make_blobs(n_samples=10000, centers=centers, cluster_std=0.6)

# 自動更新帶寬
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)

# 訓練模型
# bin_seeding參數將只初始化離散化的種子,即減少初始的種子的數量,加速算法
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(X)
print(ms.cluster_centers_)
# [[ 0.93977984 -0.92092147]
# [ 0.9950939   0.95548339]
# [-0.94187918 -0.99804463]]

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章