Get centroid of scatter plot

Time:06-12

I generated this scatter plot by plotting the first two principal components (PCA1 and PCA2) obtained from a feature extraction.

[Figure: scatter plot of PCA1 (x-axis) vs PCA2 (y-axis), points colored by class]

The plot above shows 3 classes, with PCA1 on the x-axis and PCA2 on the y-axis. I generated it as follows:

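import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

# X (feature matrix) and y (class labels 0, 1, 2) are defined elsewhere
target_names = ['class_1', 'class_2', 'class_3']
pca = PCA(n_components=2)  # assumed: the original snippet does not show how pca was created
plt.figure(figsize=(11, 8))
Xt = pca.fit_transform(X)
plot = plt.scatter(Xt[:, 0], Xt[:, 1], c=y, cmap=plt.cm.jet,
                   s=30, linewidths=0, alpha=0.7)
#centers = kmeans.cluster_centers_
#plt.scatter(centers[:, 0], centers[:, 1], c=['black', 'green', 'red'], marker='^', s=100, alpha=0.5)
plt.legend(handles=plot.legend_elements()[0], labels=list(target_names))
plt.show()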

I would like to know how to correctly compute the centroid of each class in the plot.

Here are the first few rows of the data:

Xt1 Xt2 y
-107.988187 -23.70121   1
-128.578852 -20.222378  1
-124.522967 -25.298283  1
-96.222918  -25.028239  1
-95.152954  -23.94496   1
-113.275804 -26.563129  1
-101.803    -24.22359   1
-94.662469  -22.94211   1
-104.118882 -24.037226  1
439.765098  -101.532469 2
50.100362   -34.278841  2
-69.229603  62.178599   2
-60.915475  53.296491   2
64.797364   91.991527   2
-112.815192 0.263505    0
-91.287067  -25.207217  0
-74.181941  -2.457892   0
-83.273718  -0.608004   0
-100.881393 -22.387571  0
-107.861711 -15.848869  0
-85.866992  -18.79126   0
-53.96314   -28.885316  0
-59.195432  -3.373361   0
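To make the examples that follow reproducible, the sample rows above can be loaded into NumPy arrays. This is a minimal sketch that assumes the rows are pasted verbatim into a string (the name SAMPLE is just for illustration); np.loadtxt splits on whitespace by default, and skiprows=1 skips the "Xt1 Xt2 y" header.

```python
import io

import numpy as np

# The sample rows from the question, pasted verbatim (header: Xt1 Xt2 y)
SAMPLE = """Xt1 Xt2 y
-107.988187 -23.70121   1
-128.578852 -20.222378  1
-124.522967 -25.298283  1
-96.222918  -25.028239  1
-95.152954  -23.94496   1
-113.275804 -26.563129  1
-101.803    -24.22359   1
-94.662469  -22.94211   1
-104.118882 -24.037226  1
439.765098  -101.532469 2
50.100362   -34.278841  2
-69.229603  62.178599   2
-60.915475  53.296491   2
64.797364   91.991527   2
-112.815192 0.263505    0
-91.287067  -25.207217  0
-74.181941  -2.457892   0
-83.273718  -0.608004   0
-100.881393 -22.387571  0
-107.861711 -15.848869  0
-85.866992  -18.79126   0
-53.96314   -28.885316  0
-59.195432  -3.373361   0
"""

# Whitespace-delimited parse; skip the header line
data = np.loadtxt(io.StringIO(SAMPLE), skiprows=1)
Xt = data[:, :2]             # first two principal components
y = data[:, 2].astype(int)   # class labels 0, 1, 2
```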

Any help will be greatly appreciated.

Answer:

Assuming y is an array of class labels aligned with the rows of X (and therefore of Xt), you can build a DataFrame from Xt[:, :2] and y, then use groupby('y') to take the mean of Xt[:, 0] and Xt[:, 1] for each label:

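import pandas as pd

# One row per point: the first two PCA scores plus the class label
df = pd.DataFrame(Xt[:, :2], columns=['Xt1', 'Xt2'])
df['y'] = y
# Mean of Xt1 and Xt2 within each class = centroid of that class
df.groupby('y').mean()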

This returns one row per label in y containing the means of Xt[:, 0] and Xt[:, 1], i.e. the centroid of that class in the space of the first two principal components.
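If you would rather not depend on pandas, the same per-class mean, c_k = (1/n_k) * sum of x_i over all points i with y_i = k, can be computed with NumPy boolean masks. This is a minimal sketch; class_centroids is a hypothetical helper name, and it returns the labels in sorted order, matching the index that groupby('y') produces.

```python
import numpy as np


def class_centroids(points, labels):
    """Hypothetical helper: return (classes, centroids), where centroids[i]
    is the mean of the rows of `points` whose label equals classes[i]."""
    points = np.asarray(points, dtype=float)
    labels = np.asarray(labels)
    classes = np.unique(labels)  # sorted unique labels
    # Boolean mask selects the rows of one class; mean over rows gives its centroid
    centroids = np.array([points[labels == k].mean(axis=0) for k in classes])
    return classes, centroids


# Tiny example with two classes
pts = np.array([[0.0, 0.0], [2.0, 2.0], [10.0, 0.0], [12.0, 4.0]])
lab = np.array([0, 0, 1, 1])
classes, centroids = class_centroids(pts, lab)
# classes -> [0 1]; centroids -> [[1. 1.], [11. 2.]]
```

With the OP's data this would be called as class_centroids(Xt[:, :2], y).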

Using the data snippet the OP provided, the following script computes the centroids and overlays them on the plot as 'x' markers in the same colors as their classes:

import matplotlib.pyplot as plt
import pandas as pd

# Per-class centroids; reset_index() turns the y index back into a column
df = pd.DataFrame(Xt[:, :2], columns=['Xt1', 'Xt2'])
df['y'] = y
df_centroid = df.groupby('y').mean().reset_index()

target_names = ['class_1', 'class_2', 'class_3']
plt.figure(figsize=(11, 8))
plot = plt.scatter(Xt[:, 0], Xt[:, 1], c=y, cmap=plt.cm.jet,
                   s=30, linewidths=0, alpha=0.5)
# Overlay the centroids as 'x' markers; same cmap and same label range,
# so each centroid gets the color of its class
plt.scatter(df_centroid.Xt1, df_centroid.Xt2, marker='x', s=60,
            c=df_centroid.y, cmap=plt.cm.jet)
plt.legend(handles=plot.legend_elements()[0], labels=list(target_names))
plt.show()
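Because PCA is an affine map (subtract the feature means, then project onto the principal axes), averaging within a class commutes with it: the centroid of a class in PCA space is the projection of that class's centroid in the original feature space. That is why taking the mean of the rows of Xt gives the correct centroid. The sketch below checks this numerically under stated assumptions: PCA is implemented by hand with NumPy's SVD (pca_fit_transform is a hypothetical helper, not scikit-learn's API), and random data stands in for the OP's X and y. With scikit-learn, the corresponding projection would be pca.transform(X[y == k].mean(axis=0, keepdims=True)).

```python
import numpy as np


def pca_fit_transform(X, n_components=2):
    """Hypothetical helper: PCA via SVD of the mean-centered data.
    Returns (scores, mean, components)."""
    mean = X.mean(axis=0)
    Xc = X - mean
    # Rows of Vt are the principal axes, ordered by decreasing singular value
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    components = Vt[:n_components]
    return Xc @ components.T, mean, components


rng = np.random.default_rng(0)
X = rng.normal(size=(60, 5))      # stand-in for the real feature matrix
y = np.repeat([0, 1, 2], 20)      # stand-in labels, 20 rows per class
X[y == 1] += 3.0                  # shift classes apart so their centroids differ
X[y == 2] -= 3.0

Xt, mean, components = pca_fit_transform(X)

for k in np.unique(y):
    # Centroid computed in PCA space (what the answer above does) ...
    c_pca_space = Xt[y == k].mean(axis=0)
    # ... equals the original-space centroid pushed through the same map
    c_projected = (X[y == k].mean(axis=0) - mean) @ components.T
    assert np.allclose(c_pca_space, c_projected)
```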

[Figure: the same scatter plot with the class centroids overlaid as 'x' markers]
