How to plot imshow legend based on colormap?

Time:12-10

I am trying to plot a legend based on the colormap of my imshow plot.

I could do that by specifying the colors and labels I want with matplotlib.patches (mpatches) and matplotlib.colors, as I saw in some examples, but I would rather take the colors directly from a cmap. For example:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Data
df = pd.DataFrame({"a":[0.1, 0.3, 0.4, 0.2, 0.5, 0.6],
                   "z":[5, 10, 15, 20, 25, 30],
                   "class":[1, 2, 3, 1, 3, 3]})

# 'Class' 2D cluster 
cluster = np.repeat(np.expand_dims(df["class"].values, 1), 10, 1)

f, ax = plt.subplots(nrows=1, ncols=2, figsize=(12,10))
ax[0].plot(df["a"], df["z"])
ax[1].imshow(cluster, interpolation='none', cmap="Paired",
             aspect='auto', vmin=1, vmax=3, extent=[0, 1, df["z"].min(), df["z"].max()])

ax[0].set_ylim(df["z"].min(), df["z"].max())

ax[1].legend()

Here I would like the legend to use the colors of the "Paired" cmap from my imshow plot, but with this code I only get the warning "No handles with labels found to put in legend". Could anyone help me?

I need something like this: [image]
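A minimal sketch of why the plain ax[1].legend() call finds nothing, assuming a recent matplotlib: legend() with no arguments only gathers labelled artists such as lines, patches and collections from the axes, and the AxesImage returned by imshow() is not among them, so get_legend_handles_labels() comes back empty. The Agg backend is selected only so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

# A small integer "class" image, analogous to the question's cluster array
cluster = np.array([[1], [2], [3], [1], [3], [3]])

fig, ax = plt.subplots()
im = ax.imshow(cluster, cmap="Paired", vmin=1, vmax=3, aspect="auto")

# legend() looks for labelled lines/patches/collections; the image is none of
# these and has no label, so there is nothing to put in the legend.
handles, labels = ax.get_legend_handles_labels()
print(handles, labels)  # [] []
plt.close(fig)
```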

Answer:

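ax.legend() has nothing to show because imshow() does not create any labelled artists. You need to build a custom legend instead, using the same cmap and norm (vmin and vmax) as in ax.imshow(): for each unique "class", create a dummy rectangle whose color is cmap(norm(class)), and pass these rectangles to legend() as handles.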
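The reason cmap(norm(value)) gives the right legend color is that it is the same value -> norm -> colormap lookup that imshow() applies to the data. The sketch below, assuming a recent matplotlib (Agg backend only so it runs headless), checks this with the image's to_rgba() method for the pixel holding class 2.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

cluster = np.array([[1], [2], [3], [1], [3], [3]])
norm = plt.Normalize(vmin=1, vmax=3)
cmap = plt.get_cmap("Paired")

fig, ax = plt.subplots()
im = ax.imshow(cluster, cmap=cmap, norm=norm, aspect="auto")

# Colours the image assigns to its data: data -> norm -> cmap -> RGBA
pixel_rgba = im.to_rgba(cluster)  # shape (6, 1, 4)

# Colour a legend handle would get for class value 2
legend_rgba = cmap(norm(2))

print(np.allclose(pixel_rgba[1, 0], legend_rgba))  # True
plt.close(fig)
```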

Here is an example:

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# Data
df = pd.DataFrame({"a": [0.1, 0.3, 0.4, 0.2, 0.5, 0.6],
                   "z": [5, 10, 15, 20, 25, 30],
                   "class": [1, 2, 3, 1, 3, 3]})

cluster = df["class"].values.reshape(-1, 1)

f, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 10))
ax[0].plot(df["a"], df["z"])
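# The same norm and cmap are reused below to colour the legend handles
norm = plt.Normalize(vmin=df["class"].min(), vmax=df["class"].max())
cmap = plt.get_cmap('Paired')
# origin='lower' puts the first row (z=5) at the bottom, matching ax[0]
ax[1].imshow(cluster, interpolation='none', cmap=cmap, norm=norm, origin='lower',
             aspect='auto', extent=[0, 1, df["z"].min(), df["z"].max()])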

ax[0].set_ylim(df["z"].min(), df["z"].max())

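# One zero-size dummy rectangle per class, coloured exactly as imshow colours that value
handles = [plt.Rectangle((0, 0), 0, 0, color=cmap(norm(i)), label=str(i))
           for i in np.unique(df["class"])]  # np.unique sorts the classes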
ax[1].legend(handles=handles, title='class')
plt.tight_layout()
plt.show()

[image: imshow with custom legend]
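If you would rather not repeat the cmap and norm, the AxesImage returned by imshow() already carries them as im.cmap and im.norm. The sketch below reads them from the image; legend_handles_from_image is a hypothetical helper name, and it uses matplotlib.patches.Patch instead of a zero-size Rectangle, a common alternative for legend handles. It assumes a recent matplotlib and selects the Agg backend only so it runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import numpy as np


def legend_handles_from_image(im, values):
    """Hypothetical helper: one Patch per value, coloured with the image's own cmap and norm."""
    return [Patch(facecolor=im.cmap(im.norm(v)), label=str(v)) for v in values]


cluster = np.array([[1], [2], [3], [1], [3], [3]])
fig, ax = plt.subplots()
im = ax.imshow(cluster, cmap="Paired", vmin=1, vmax=3, aspect="auto")

classes = np.unique(cluster)  # sorted unique class values
handles = legend_handles_from_image(im, classes)
leg = ax.legend(handles=handles, title="class")
print([t.get_text() for t in leg.get_texts()])  # ['1', '2', '3']
plt.close(fig)
```

Because the colors come from the image itself, changing cmap or vmin/vmax in the imshow() call keeps the legend in sync without touching the legend code.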
