Scatterplot set color and legend based on data point label for matplotlib

Time:08-10

I'm trying to make a scatter plot and color the points by ID, and I want the legend to display all the IDs with their respective colors. Here's my code:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure

x = np.random.uniform(0, 100, 2000)
y = np.random.uniform(0, 50, 2000)
ID = np.random.randint(0,100,2000)

fig, ax = plt.subplots(figsize=(10, 8),dpi = 80)
scatter = ax.scatter(x,
                     y,
                    c = ID)

ax.set_xlabel('x')
ax.set_ylabel('y')
ax.legend(*scatter.legend_elements(),
          loc="center left", 
          title='ID', 
          bbox_to_anchor=(1, 0.5)
         )    
ax.ticklabel_format(useOffset=False)
ax.tick_params(axis = 'x',labelrotation = 45)

But the legend only gives me a list of IDs at even intervals instead of all the actual IDs used in the data.

What I have:

How do I set the correct legend so it displays ALL the IDs used in coloring the points? Thanks.

CodePudding user response:

matplotlib is currently treating your colors as values on a continuous scale rather than as categories, so the legend only shows a few evenly spaced values.

To get one entry per ID, pass the unique IDs you want labeled to the num argument of .legend_elements:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure

x = np.random.uniform(0, 100, 2000)
y = np.random.uniform(0, 50, 2000)
ID = np.random.randint(0,100,2000)

fig, ax = plt.subplots(figsize=(10, 8),dpi = 80)
scatter = ax.scatter(x,
                     y,
                    c = ID)

ax.set_xlabel('x')
ax.set_ylabel('y')
ax.legend(*scatter.legend_elements(num=list(np.unique(ID))),
          loc="center left", 
          title='ID', 
          bbox_to_anchor=(1, 0.5),
          ncol=5
         )    
ax.ticklabel_format(useOffset=False)
ax.tick_params(axis = 'x',labelrotation = 45)
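
To see the effect of num without looking at the figure, you can count the legend entries that legend_elements returns. The following is a minimal sketch with smaller made-up data and the non-interactive Agg backend (both are assumptions for illustration, not part of the question). It also tries num=None, which the matplotlib documentation describes as using all unique elements of the color array.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs without a display
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
x = rng.uniform(0, 100, 500)
y = rng.uniform(0, 50, 500)
ID = rng.integers(0, 50, 500)

fig, ax = plt.subplots()
scatter = ax.scatter(x, y, c=ID)

# Default (num="auto"): a tick locator picks a handful of evenly spaced values.
auto_handles, auto_labels = scatter.legend_elements()

# Explicit list of values: one legend entry per value passed in.
list_handles, list_labels = scatter.legend_elements(num=list(np.unique(ID)))

# num=None: one legend entry per unique value in the color array.
none_handles, none_labels = scatter.legend_elements(num=None)

print(len(auto_handles), len(list_handles), len(none_handles))
```

The default call should report far fewer entries than there are unique IDs, while the other two should report one entry per ID.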

Alternatively, you can iterate over your unique IDs and add a separate scatter for each one. This way matplotlib treats each ID as its own labeled entry on the plot.

You'll also need to sample a sequential colormap at one point per unique ID, so that no color repeats, and pair those colors with the unique IDs.

import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)

x = rng.uniform(0, 100, 2000)
y = rng.uniform(0, 50, 2000)
ID = rng.integers(0, 100, 2000)
unique_ids = np.unique(ID)

fig, ax = plt.subplots(figsize=(12, 6))
# Sample one evenly spaced color per unique ID from the colormap.
# (matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9;
# plt.get_cmap is used here instead.)
cmap = plt.get_cmap('viridis')
colors = cmap(np.linspace(0, 1, len(unique_ids)))

# One scatter call per ID, so each ID gets its own legend entry.
for _id, color in zip(unique_ids, colors):
    mask = (ID == _id)
    ax.scatter(x[mask], y[mask], label=_id, color=color)

ax.legend(ncol=5, bbox_to_anchor=(1, 1), loc='upper left', title='Tracker IDs')
fig.tight_layout()
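
If you need this in more than one plot, the loop can be wrapped in a small helper. This is only a sketch: scatter_by_id and its cmap_name parameter are hypothetical names, not matplotlib API, and the data is made up. It uses ax.get_legend_handles_labels() to confirm that the legend would get exactly one entry per unique ID.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs without a display
import numpy as np
import matplotlib.pyplot as plt


def scatter_by_id(ax, x, y, ids, cmap_name="viridis"):
    """Draw one scatter per unique ID and return the unique IDs (hypothetical helper)."""
    unique_ids = np.unique(ids)
    # One evenly spaced color per unique ID.
    colors = plt.get_cmap(cmap_name)(np.linspace(0, 1, len(unique_ids)))
    for _id, color in zip(unique_ids, colors):
        mask = ids == _id
        ax.scatter(x[mask], y[mask], label=str(_id), color=color)
    return unique_ids


rng = np.random.default_rng(0)
x = rng.uniform(0, 100, 300)
y = rng.uniform(0, 50, 300)
ID = rng.integers(0, 20, 300)

fig, ax = plt.subplots()
unique_ids = scatter_by_id(ax, x, y, ID)

# The labeled artists that ax.legend() would pick up, in drawing order.
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles, labels, ncol=5, title="ID")
```

Compared with the legend_elements approach, this creates one artist per ID, which is simple to reason about but may get slow when there are very many IDs.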
