I have a dataset in .csv format that looks like this:
x,y,z, label
2,1,3, A
5,3,1, B
6,2,2, C
9,5,3, B
2,3,4, A
4,1,4, A
I would like to apply k-means clustering to this dataset, which has three dimensions (x, y, z). After that, I would like to visualize the clusters in 3D, with each cluster labeled in the plot.
For a 2-dimensional dataset I have used the following:
kmeans_labels = cluster.KMeans(n_clusters=5).fit_predict(data)
and plotted the result like this:
plt.scatter(standard_embedding[:, 0], standard_embedding[:, 1], c=kmeans_labels, s=0.1, cmap='Spectral');
Similarly, I would like to plot the 3-dimensional clustering with labels. Please let me know if you need more details.
CodePudding user response:
Would something like this work?
import numpy as np
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
data = np.array([[2,1,3], [5,3,1], [6,2,2], [9,5,3], [2,3,4], [4,1,4]])
cluster_count = 3
km = KMeans(cluster_count)
clusters = km.fit_predict(data)
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
scatter = ax.scatter(data[:, 0], data[:, 1], data[:, 2], c=clusters, alpha=1)
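# Write one name at each cluster center. K-means numbers its clusters
# arbitrarily, so "A" here is just a name for cluster 0, not the original label A.
labels = ["A", "B", "C"]
for i, label in enumerate(labels):
    ax.text(km.cluster_centers_[i, 0], km.cluster_centers_[i, 1], km.cluster_centers_[i, 2], label)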
ax.set_title("3D K-Means Clustering")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("z")
plt.show()
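The array above is typed in by hand, while the question starts from a .csv file with a label column. Here is a rough sketch of how the same clustering could be fed from the file, assuming pandas is available. The inline `csv_text` string is only a stand-in for the real file, and the `n_init` and `random_state` values are my own choices to make runs repeatable, not something from the question.

```python
import io

import pandas as pd
from sklearn.cluster import KMeans

# Stand-in for the real file; with a file on disk, pass its path to read_csv instead.
csv_text = """x,y,z, label
2,1,3, A
5,3,1, B
6,2,2, C
9,5,3, B
2,3,4, A
4,1,4, A
"""

# skipinitialspace=True drops the blank after each comma, so the last
# column is named "label" rather than " label".
df = pd.read_csv(io.StringIO(csv_text), skipinitialspace=True)

# Cluster on the numeric columns only; the label column is kept aside.
features = df[["x", "y", "z"]].to_numpy()

# n_init and random_state are assumed values, set explicitly for repeatable runs.
km = KMeans(n_clusters=3, n_init=10, random_state=0)
df["cluster"] = km.fit_predict(features)

print(df)
```

The `features` array can then be passed to `ax.scatter` exactly like `data` in the code above.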
EDIT
If you want a legend instead, just do this:
import numpy as np
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
data = np.array([[2,1,3], [5,3,1], [6,2,2], [9,5,3], [2,3,4], [4,1,4]])
cluster_count = 3
km = KMeans(cluster_count)
clusters = km.fit_predict(data)
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
scatter = ax.scatter(data[:, 0], data[:, 1], data[:, 2], c=clusters, alpha=1)
handles = scatter.legend_elements()[0]
# The names are assigned to clusters 0, 1, 2 in order; they are not tied to the original labels.
ax.legend(title="Clusters", handles=handles, labels=["A", "B", "C"])
ax.set_title("3D K-Means Clustering")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("z")
plt.show()
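One caveat with the legend above: the names "A", "B", "C" are handed to clusters 0, 1, 2 in order, and k-means numbers its clusters arbitrarily. If the legend should reflect the original label column, one possible approach (my suggestion, not part of the answer above) is to name each cluster after the most common original label among its members. The helper `majority_label_per_cluster` is a hypothetical name, and `n_init` and `random_state` are assumed values.

```python
from collections import Counter

import numpy as np
from sklearn.cluster import KMeans

data = np.array([[2, 1, 3], [5, 3, 1], [6, 2, 2], [9, 5, 3], [2, 3, 4], [4, 1, 4]])
# Original labels from the question's CSV, in row order.
original_labels = np.array(["A", "B", "C", "B", "A", "A"])

km = KMeans(n_clusters=3, n_init=10, random_state=0)
clusters = km.fit_predict(data)


def majority_label_per_cluster(clusters, labels):
    """Map each cluster id to the most frequent original label among its points."""
    names = {}
    for cluster_id in np.unique(clusters):
        members = labels[clusters == cluster_id]
        names[int(cluster_id)] = str(Counter(members).most_common(1)[0][0])
    return names


legend_names = majority_label_per_cluster(clusters, original_labels)
print(legend_names)
```

The result could replace the hard-coded list in the legend call, for example `labels=[legend_names[i] for i in sorted(legend_names)]`. Two clusters may end up with the same name if the same label dominates both.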