K均值(K-Means)数字聚类

# -*- coding: utf-8 -*-
from sklearn.datasets import load_digits
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
from scipy.stats import mode
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
import seaborn as sns
from sklearn.manifold import TSNE # 非线性嵌入算法模块

# Load the handwritten-digits dataset and partition its samples into ten
# clusters. random_state is pinned so repeated runs use the same seed.
digits = load_digits()
kmeans = KMeans(n_clusters=10, random_state=0)
clusters = kmeans.fit_predict(digits.data)

# View each of the ten cluster centres as an 8x8 image, one per cell of a
# 2x5 grid of axes.
fig, ax = plt.subplots(2, 5, figsize=(8, 3))
centers = kmeans.cluster_centers_.reshape(10, 8, 8)
for axi, center in zip(ax.flat, centers):
    axi.set(xticks=[], yticks=[])  # hide tick marks; only the image matters
    axi.imshow(center, interpolation='nearest', cmap=plt.cm.binary)

# Cluster ids are arbitrary, so relabel every cluster with the true digit that
# occurs most often among its members before scoring against the ground truth.
labels = np.zeros_like(clusters)  # all-zero array shaped like clusters
for i in range(10):
    mask = (clusters == i)  # True for the samples assigned to cluster i
    if not mask.any():
        # No sample was assigned to this cluster id: there is nothing to
        # relabel and no well-defined most-common digit to take.
        continue
    # mode() yields the most frequent member (the smallest one on a tie)
    # together with its count; element [0] is the member itself.
    labels[mask] = mode(digits.target[mask])[0]
print('accuracy_score is: ', accuracy_score(digits.target, labels))
# Draw the confusion matrix on a fresh figure. The matrix is transposed so the
# true labels run along the x axis and the predicted labels along the y axis.
fig2, ax2 = plt.subplots()
mat = confusion_matrix(digits.target, labels)
transposed = mat.T
heatmap_options = dict(
    square=True,
    annot=True,
    fmt='d',
    cbar=False,
    xticklabels=digits.target_names,
    yticklabels=digits.target_names,
)
sns.heatmap(transposed, **heatmap_options)
plt.xlabel('true label')
plt.ylabel('predicted label')

# Pre-processing: project the data down to two dimensions with t-SNE.
# init may be 'random' (random initialisation), 'pca' (PCA-based
# initialisation, used here) or a numpy array of shape
# (n_samples, n_components).
tsne = TSNE(n_components=2, init='pca', random_state=0)
digits_proj = tsne.fit_transform(digits.data)
# Cluster the projected points.
kmeans = KMeans(n_clusters=10, random_state=0)
clusters = kmeans.fit_predict(digits_proj)
# Relabel each cluster with the most common true digit among its members.
labels = np.zeros_like(clusters)
for i in range(10):
    mask = (clusters == i)
    if not mask.any():
        # Nothing was assigned to this cluster id, so there is no mode to take.
        continue
    labels[mask] = mode(digits.target[mask])[0]
print('accuracy_score is: ', accuracy_score(digits.target, labels))
plt.show()

留下评论

通过 WordPress.com 设计一个这样的站点
从这里开始