I'm trying to make a scatter plot and color the points by ID, and I want to make sure my legend displays all the IDs with their respective colors. Here's my code:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
# Synthetic data: 2000 random points, each tagged with an integer ID.
x = np.random.uniform(0, 100, 2000)
y = np.random.uniform(0, 50, 2000)
ID = np.random.randint(0, 100, 2000)

fig, ax = plt.subplots(figsize=(10, 8), dpi=80)

# Passing the IDs as `c` colors each point by its ID value.
scatter = ax.scatter(x, y, c=ID)

ax.set_xlabel('x')
ax.set_ylabel('y')
ax.ticklabel_format(useOffset=False)
ax.tick_params(axis='x', labelrotation=45)

# Build the legend from the scatter's own handles/labels (default settings)
# and place it just outside the right edge of the axes.
handles, labels = scatter.legend_elements()
ax.legend(
    handles,
    labels,
    title='ID',
    loc="center left",
    bbox_to_anchor=(1, 0.5),
)
But the legend only shows a list of IDs at evenly spaced intervals instead of all the actual IDs used in the data.
How do I set up the legend so that it displays ALL the IDs used to color the points? Thanks.
CodePudding user response:
matplotlib
is currently inferring your colors to be on a continuous scale instead of a categorical one.
You can pass the unique IDs you want labels to be created for into the num
argument of .legend_elements
to do this:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
# Synthetic data: 2000 random points, each tagged with an integer ID.
x = np.random.uniform(0, 100, 2000)
y = np.random.uniform(0, 50, 2000)
ID = np.random.randint(0, 100, 2000)

fig, ax = plt.subplots(figsize=(10, 8), dpi=80)
scatter = ax.scatter(x, y, c=ID)

ax.set_xlabel('x')
ax.set_ylabel('y')

# Ask for one legend entry per distinct ID by handing the explicit list of
# values to `num`, rather than letting matplotlib choose the entries itself.
id_values = list(np.unique(ID))
handles, labels = scatter.legend_elements(num=id_values)

# Five columns keep the long legend compact; it is anchored just outside the
# right edge of the axes.
ax.legend(
    handles,
    labels,
    loc="center left",
    title='ID',
    bbox_to_anchor=(1, 0.5),
    ncol=5,
)

ax.ticklabel_format(useOffset=False)
ax.tick_params(axis='x', labelrotation=45)
Alternatively, you can iterate over your unique IDs and add a separate scatter for each one. This way matplotlib
will treat your IDs as unique entries on your plot.
You'll additionally need to segment a sequential colormap to get non-repeating colors and pair those colors with the unique IDs.
import numpy as np
import matplotlib.pyplot as plt

# Seeded generator so the example produces the same data on every run.
rng = np.random.default_rng(0)
x = rng.uniform(0, 100, 2000)
y = rng.uniform(0, 50, 2000)
ID = rng.integers(0, 100, 2000)
unique_ids = np.unique(ID)

fig, ax = plt.subplots(figsize=(12, 6))

# Resample 'viridis' to exactly one color per unique ID.
# NOTE: `matplotlib.cm.get_cmap` was deprecated in matplotlib 3.7 and removed
# in 3.9; `plt.get_cmap(name, lut)` takes the same arguments and still works.
cmap = plt.get_cmap('viridis', len(unique_ids))

# One scatter call per ID: each call carries its own label, so the legend
# gets exactly one entry per ID.
for _id, color in zip(unique_ids, cmap.colors):
    mask = (ID == _id)
    ax.scatter(x[mask], y[mask], label=_id, color=color)

# Place the legend outside the axes, to the right, in five columns.
ax.legend(ncol=5, bbox_to_anchor=(1, 1), loc='upper left', title='Tracker IDs')
fig.tight_layout()