K-Means centroids not visible in 3D clustering plot

Time:12-22

When I plot the clustering result in 2D using this code:

from matplotlib import pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans

# create 2d data 
x, label = make_blobs(n_samples=3000, n_features=2, centers=4,
                      cluster_std=3, random_state=42)

# cluster data
kmeans = KMeans(init="k-means++", n_clusters=4, random_state=42)
kmeans.fit(x)

# plot clusters and centroids
fig = plt.figure(figsize=(10,5))
ax = fig.add_subplot(111)
ax.scatter(x[:,0],x[:,1], c=kmeans.labels_, cmap='viridis',
           edgecolor='k', s=40, alpha = 0.5)
ax.scatter(kmeans.cluster_centers_[:,0], kmeans.cluster_centers_[:,1],
           s = 300, c = 'r', marker='*', label = 'Centroid')
ax.set_title("2D Kmeans clustering")
ax.set_xlabel("X")
ax.set_ylabel("Y")   
plt.show()

I get the following output:

Output 1

But when I try to plot the clustering result in 3D using this code:

from matplotlib import pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans

# create 3d data 
x, label = make_blobs(n_samples=3000, n_features=3, centers=4,
                      cluster_std=3, random_state=42)

# cluster data
kmeans = KMeans(init="k-means++", n_clusters=4, random_state=42)
kmeans.fit(x)

# plot clusters and centroids
fig = plt.figure(figsize=(10,5))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(x[:,0], x[:,1], x[:,2] ,c=kmeans.labels_, cmap='viridis',
           edgecolor='k', s=40, alpha = 0.5)
ax.scatter(kmeans.cluster_centers_[:,0], kmeans.cluster_centers_[:,1],
           kmeans.cluster_centers_[:,2], s = 300, c = 'r',
           marker='*', label = 'Centroid')
ax.set_title("3D Kmeans clustering")
ax.set_xlabel("X")
ax.set_ylabel("Y")   
ax.set_zlabel("Z") 
plt.show()

I get the following output:

Output 2

As you can see, the centroids of the clusters are not visible. I would like to see the centroid stars in the 3D plot. How can I achieve this?

Thanks in advance for any help!

Answer:

The centroid stars are surrounded by the clouds of points and are therefore hard to see. You can make them more visible by making the points smaller (e.g. s=10 or less) and more transparent (e.g. alpha=0.1 or less), as shown in the following code line:

ax.scatter(x[:,0], x[:,1], x[:,2], c=kmeans.labels_, cmap='viridis',
           edgecolor='k', s=10, alpha=0.1)
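For reference, here is a sketch of the question's full 3D script with that change applied. The s=10 and alpha=0.1 values come from the answer, and init is spelled "k-means++", scikit-learn's name for its default seeding. Two things are additions not in the answer: n_init=10 (which pins the older scikit-learn default so results do not shift between library versions) and ax.legend() (so the existing label='Centroid' actually appears).

```python
from matplotlib import pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans

# create 3d data (same settings as in the question)
x, label = make_blobs(n_samples=3000, n_features=3, centers=4,
                      cluster_std=3, random_state=42)

# cluster data; n_init=10 is an addition that fixes the number of restarts
kmeans = KMeans(init="k-means++", n_clusters=4, n_init=10, random_state=42)
kmeans.fit(x)

fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot(111, projection='3d')

# data points: smaller and more transparent, as suggested in the answer
ax.scatter(x[:, 0], x[:, 1], x[:, 2], c=kmeans.labels_, cmap='viridis',
           edgecolor='k', s=10, alpha=0.1)

# centroids: large red stars, unchanged from the question
ax.scatter(kmeans.cluster_centers_[:, 0], kmeans.cluster_centers_[:, 1],
           kmeans.cluster_centers_[:, 2], s=300, c='r',
           marker='*', label='Centroid')

ax.set_title("3D Kmeans clustering")
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")
ax.legend()  # addition: shows the 'Centroid' label
plt.show()
```

If the stars are still partly hidden from some viewing angles, lowering alpha further (as the answer suggests, "0.1 or smaller") is the simplest knob to turn.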
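A different route, not part of the answer above: by default mplot3d decides draw order by sorting artists on their average position along the view direction, so the whole point cloud can be drawn over the stars no matter which scatter call comes last. Since Matplotlib 3.5, Axes3D accepts computed_zorder=False, which makes it use each artist's zorder instead. The sketch below assumes Matplotlib 3.5 or later; it keeps the question's point styling, gives the centroids a higher zorder, and passes depthshade=False so the stars are not faded by depth shading.

```python
from matplotlib import pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans

# create 3d data and cluster it (same settings as in the question)
x, label = make_blobs(n_samples=3000, n_features=3, centers=4,
                      cluster_std=3, random_state=42)
kmeans = KMeans(init="k-means++", n_clusters=4, n_init=10, random_state=42)
kmeans.fit(x)

fig = plt.figure(figsize=(10, 5))
# computed_zorder=False (Matplotlib >= 3.5): draw artists by their zorder
# instead of by their average depth along the view direction
ax = fig.add_subplot(111, projection='3d', computed_zorder=False)

# data points keep the question's styling, with a low zorder
ax.scatter(x[:, 0], x[:, 1], x[:, 2], c=kmeans.labels_, cmap='viridis',
           edgecolor='k', s=40, alpha=0.5, zorder=1)

# centroids get a higher zorder so they are always drawn on top;
# depthshade=False stops mplot3d from fading them with distance
ax.scatter(kmeans.cluster_centers_[:, 0], kmeans.cluster_centers_[:, 1],
           kmeans.cluster_centers_[:, 2], s=300, c='r', marker='*',
           label='Centroid', zorder=2, depthshade=False)

ax.set_title("3D Kmeans clustering")
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")
ax.legend()
plt.show()
```

The trade-off is that the stars are now painted over the points even when they are geometrically behind the cloud, which is usually what you want for centroid markers but does sacrifice some depth realism.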