Matplotlib scatter plot with different colors/label based on a category

Time:05-23

I have 2D data with n points, a set of k < n categories, and a list of n category labels, one per data point (so the label list has the same length as the data). Is it possible to draw a scatter plot and show a legend with one entry per category instead of n entries?

The data looks like this:

import numpy as np

categories = {'fruit': 0, 'animal': 1}
data = np.array([[1, 1], [2, 1], [0, 1], [3, 2], [3, 3]])
labels = ['fruit', 'fruit', 'animal', 'animal', 'fruit']

Here is my attempt:

import matplotlib.pyplot as plt

# One color-cycle spec ('C0', 'C1', ...) per point, based on its category
color_categories = [f'C{categories[l]}' for l in labels]
plt.scatter(data[:, 0], data[:, 1], c=color_categories, label=labels)
plt.legend()

With the code above, the legend shows a single entry containing the whole list instead of one entry per category. I reached an acceptable solution by looping over the categories and filtering the points for each one:

for k in categories:
    # Indices of the points whose label matches this category
    positions = [i for i, l in enumerate(labels) if l == k]
    points_to_plot = np.take(data, positions, axis=0)
    plt.scatter(points_to_plot[:, 0], points_to_plot[:, 1], label=k)
plt.legend()

But I don't like this approach. Is there a more direct way to do this?
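A more direct option that keeps a single scatter call, assuming matplotlib 3.1 or later, is to color the points by the integer codes from the `categories` dict and let `PathCollection.legend_elements()` build one legend handle per distinct code. With `num=None` it returns a handle for every unique value in ascending order, so the category names are sorted by code to line up with the handles. The `tab10` colormap is an arbitrary choice for this sketch.

```python
import matplotlib.pyplot as plt
import numpy as np

categories = {'fruit': 0, 'animal': 1}
data = np.array([[1, 1], [2, 1], [0, 1], [3, 2], [3, 3]])
labels = ['fruit', 'fruit', 'animal', 'animal', 'fruit']

# Integer code for each point, taken from the categories mapping
codes = np.array([categories[l] for l in labels])

fig, ax = plt.subplots()
sc = ax.scatter(data[:, 0], data[:, 1], c=codes, cmap='tab10')

# One handle per unique code, in ascending code order
handles, _ = sc.legend_elements(num=None)
# Category names ordered by their code so they match the handles
names = sorted(categories, key=categories.get)
legend = ax.legend(handles, names)
```

All points stay in one artist; only the legend is built per category. In a script, call `plt.show()` afterwards to display the figure.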

Answer:

You can use seaborn:

import seaborn as sns
import numpy as np

data = np.array([[1,1], [2,1], [0,1], [3,2], [3,3]])
labels = ['fruit', 'fruit', 'animal', 'animal', 'fruit']
sns.scatterplot(x=data[:, 0], y=data[:, 1], hue=labels)

It gives:

[Figure: resulting scatter plot]

Answer:

Here's a solution that uses only matplotlib (and numpy, since that's what you used to define your original dataset). It builds a set of all the categories, then selects and plots the points belonging to each category separately.

import numpy as np
import matplotlib.pyplot as plt

data = np.array([ [1,1], [2,1], [0,1], [3,2], [3,3] ])
labels = [ 'fruit', 'fruit', 'animal', 'animal', 'fruit' ]

for l in set(labels):
    x = [data[i][0] for i in range(len(data)) if labels[i] == l]
    y = [data[i][1] for i in range(len(data)) if labels[i] == l]
    plt.plot(x, y, 'o', label=l)
    # alternatively:
    # plt.scatter(x, y, label=l)

plt.legend()

[Figure: resulting scatter plot]

If you want your legend to be alphabetical, you could change the for loop to:

for l in sorted(set(labels)):
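Because `data` is already a NumPy array, the per-category selection can also be done with a boolean mask instead of two list comprehensions. This is a minimal sketch that converts `labels` to a NumPy array so it can be compared elementwise; `np.unique` returns the distinct labels already sorted, so the legend comes out in alphabetical order without a separate `sorted()` call.

```python
import matplotlib.pyplot as plt
import numpy as np

data = np.array([[1, 1], [2, 1], [0, 1], [3, 2], [3, 3]])
labels = np.array(['fruit', 'fruit', 'animal', 'animal', 'fruit'])

fig, ax = plt.subplots()
for l in np.unique(labels):   # distinct labels, sorted alphabetically
    mask = labels == l        # True for the points in this category
    ax.scatter(data[mask, 0], data[mask, 1], label=l)
legend = ax.legend()
```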