# -*- coding: utf-8 -*-
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.metrics import pairwise_distances_argmin

try:
    from sklearn.datasets import make_blobs
except ImportError:
    # Legacy module path, kept as a fallback; newer scikit-learn releases
    # no longer provide it.
    from sklearn.datasets.samples_generator import make_blobs
# Apply seaborn's default theme so the matplotlib figure drawn by this script
# uses seaborn styling.
sns.set()
def find_clusters(X, n_clusters, rseed=2):
    """Cluster the rows of ``X`` into ``n_clusters`` groups with basic k-means.

    The initial centres are ``n_clusters`` distinct rows of ``X`` chosen with
    a ``numpy.random.RandomState`` seeded by ``rseed``. The function then
    alternates between assigning every row to its nearest centre and moving
    every centre to the mean of the rows assigned to it, until the centres
    stop changing.

    The chosen seed-row indices and the initial centres are printed to
    stdout.

    Args:
        X: 2-D array-like with one sample per row.
        n_clusters: Number of clusters; must satisfy
            ``1 <= n_clusters <= X.shape[0]``.
        rseed: Seed for the random choice of the initial centres.

    Returns:
        A ``(centers, labels)`` tuple: ``centers`` has one row per cluster
        and ``labels[j]`` is the index of the centre nearest to ``X[j]``.

    Raises:
        ValueError: If ``X`` is not 2-D, or ``n_clusters`` is outside
            ``1 .. X.shape[0]``.
    """
    X = np.asarray(X)
    if X.ndim != 2:
        raise ValueError(
            "X must be a 2-D array of shape (n_samples, n_features); "
            "got %d dimension(s)" % X.ndim
        )
    n_samples = X.shape[0]
    if not 1 <= n_clusters <= n_samples:
        raise ValueError(
            "n_clusters must be between 1 and the number of samples (%d); "
            "got %r" % (n_samples, n_clusters)
        )

    rng = np.random.RandomState(rseed)

    # Randomly pick the initial centres. The permutation is drawn exactly
    # once so that the indices printed here are the ones actually used.
    seed_idx = rng.permutation(n_samples)[:n_clusters]
    print('rng.permutation(X.shape[0])[:n_clusters]:\n', seed_idx)
    centers = X[seed_idx]
    print('centers:\n', centers)

    while True:
        # Assign every sample to its nearest centre.
        labels = pairwise_distances_argmin(X, centers)

        # Move each centre to the mean of its members (mean over rows).
        # A cluster with no members keeps its previous centre; averaging an
        # empty selection would yield NaN and poison later iterations.
        updated = []
        for k in range(n_clusters):
            members = X[labels == k]
            updated.append(members.mean(axis=0) if len(members) else centers[k])
        new_centers = np.array(updated)

        # Converged once an update leaves every centre where it was.
        if np.array_equal(centers, new_centers):
            break
        centers = new_centers

    return centers, labels
# Generate 300 sample points grouped around 4 blob centres.
X, y_true = make_blobs(
    n_samples=300,
    centers=4,
    cluster_std=0.60,
    random_state=0,
)

# Cluster the points into 4 groups with find_clusters.
centers, labels = find_clusters(X, n_clusters=4)

# Colour every point by its assigned cluster and show the figure.
plt.scatter(X[:, 0], X[:, 1], s=50, c=labels, cmap='viridis')
plt.show()