Plot a single cluster

Time:10-08

I am working with HDBSCAN and I want to plot only one cluster of the data.

This is my current code:

import hdbscan
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.datasets import make_blobs

blobs, labels = make_blobs(n_samples=2000, n_features=10)

clusterer = hdbscan.HDBSCAN(min_cluster_size=15).fit(blobs)
color_palette = sns.color_palette('deep', 8)
cluster_colors = [color_palette[x] if x >= 0
                  else (0.5, 0.5, 0.5)
                  for x in clusterer.labels_]
cluster_member_colors = [sns.desaturate(x, p) for x, p in
                         zip(cluster_colors, clusterer.probabilities_)]
plt.scatter(blobs[:, 2], blobs[:, 5], s=50, linewidth=0, c=cluster_member_colors, alpha=0.25)
plt.show()

I know that the data has 3 clusters, but how can I plot only one of them?

If I have a point from a cluster, how can I find out which row of the pandas data frame corresponds to that point?

Answer:

I recommend putting all the relevant information into a pandas DataFrame instead.

df = pd.DataFrame(blobs)

clusterer = hdbscan.HDBSCAN(min_cluster_size=15).fit(blobs)
df['cluster'] = clusterer.labels_
df['probability'] = clusterer.probabilities_

color_palette = sns.color_palette('deep', 8)
def get_member_color(row):
    if row['cluster'] >= 0:
        c = color_palette[int(row['cluster'])]
    else:
        c = (0.5, 0.5, 0.5)
    
    member_color = sns.desaturate(c, row['probability'])
    return member_color

df['member_color'] = df.apply(get_member_color, axis=1)

Now you can easily filter the rows by the cluster they belong to. For example, to plot all samples belonging to cluster 2, we can do:

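# Keep only the rows that HDBSCAN assigned to cluster 2
df2 = df.loc[df['cluster'] == 2]
# Pass the colors as a plain list so matplotlib does not depend on the filtered index
plt.scatter(df2.iloc[:, 2], df2.iloc[:, 5], s=50, linewidth=0, c=df2['member_color'].tolist(), alpha=0.25)
plt.show()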
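
As for the second part of the question: the rows of df are in the same order as the rows of blobs, so the DataFrame index identifies each point. Below is a minimal sketch of how that lookup might work. The names points, labels and find_row are made up for illustration, and labels is a hand-written stand-in for clusterer.labels_ so the snippet runs without hdbscan. HDBSCAN marks noise with -1, so it can also help to check which labels exist before filtering on a particular one.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-ins for `blobs` and `clusterer.labels_`,
# so this sketch runs without hdbscan installed.
rng = np.random.default_rng(0)
points = rng.normal(size=(6, 3))
labels = np.array([0, 2, -1, 2, 1, 2])  # -1 marks noise points

df = pd.DataFrame(points)
df['cluster'] = labels

# Which cluster labels exist, and how many points each one holds.
cluster_sizes = df['cluster'].value_counts().to_dict()

# Index labels of every row that belongs to cluster 2.
rows_in_cluster_2 = df.index[df['cluster'] == 2].tolist()


def find_row(frame, point, feature_columns):
    """Return the index labels of rows whose feature values match `point`."""
    features = frame[feature_columns].to_numpy()
    matches = np.isclose(features, point).all(axis=1)
    return frame.index[matches].tolist()


# Given the coordinates of one point, recover the row it came from.
feature_columns = [0, 1, 2]
row_of_point = find_row(df, points[3], feature_columns)

print(cluster_sizes)
print(rows_in_cluster_2)
print(row_of_point)
```

If you already hold the positional index of the point (for example from a boolean mask on clusterer.labels_), df.iloc[position] gives the row directly and the coordinate comparison is not needed.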