(十四)sklearn 均值漂移聚类

 代码大部分来自官方文档,可以直接运行

import numpy as np

# estimate_bandwidth用于设置带宽
from sklearn.cluster import MeanShift, estimate_bandwidth

# 生成测试数据所需要的库
from sklearn.datasets.samples_generator import make_blobs

# 以(1,1),(-1,-1),(1,-1)为中心生成10000个标准差为0.6的测试数据集
centers = np.array([[1, 1], [-1, -1], [1, -1]])
X, _ = make_blobs(n_samples=10000, centers=centers, cluster_std=0.6)

# 自动更新带宽
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)

# 训练模型
# bin_seeding参数将只初始化离散化的种子,即减少初始的种子的数量,加速算法
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(X)
print(ms.cluster_centers_)
# [[ 0.93977984 -0.92092147]
# [ 0.9950939   0.95548339]
# [-0.94187918 -0.99804463]]

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章