How to extend the color palette in matplotlib?


I coded the following function:

def plot_cumulative_dynamic_auc(risk_score, label, color=None):

    auc, mean_auc = cumulative_dynamic_auc(y_trn, y_test, risk_score, times)
    plt.plot(times, auc, marker="o", color=color, label=label)
    plt.xlabel("days from enrollment")
    plt.ylabel("time-dependent AUC")
    plt.axhline(mean_auc, color=color, linestyle="--")

And then the for-loop:

for i, col in enumerate(num_columns):
    plot_cumulative_dynamic_auc(X_test.iloc[:, i], col, color="C{}".format(i))
    ret = concordance_index_ipcw(y_trn, y_test, X_test.iloc[:, i], tau=times[-1])

The for-loop iterates over num_columns, which contains 40 variables, but the standard palette only offers 10 colors. I want every variable to have its own color. Is there a way to code this that also stays flexible with respect to the number of variables?

CodePudding user response:

Matplotlib offers tab20, but with only 20 colors it is too restrictive for your case. Since you have a lot of lines, a possible solution is to sample the colors from a colormap, or from more than one. Take a look at the colormaps available in Matplotlib.

When the colors are sampled from a single colormap over its full range, the first and last lines can end up with similar colors. So if the colormap is cyclic (such as hsv), it might be a good idea to restrict the discretization range, for example discr = np.linspace(0, 0.75, N).
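A minimal sketch of the single-colormap approach, assuming hsv and the 0.75 cut-off mentioned above. The helper name sample_colors and its parameters are hypothetical and only serve the illustration:

```python
import numpy as np
from matplotlib import pyplot as plt


def sample_colors(n, cmap_name="hsv", upper=0.75):
    """Return an (n, 4) array of RGBA colors sampled from one colormap.

    Keeping `upper` below 1 truncates the sampled range, so a cyclic
    colormap such as hsv does not wrap around to its starting color.
    """
    cmap = plt.get_cmap(cmap_name)
    # evenly spaced positions in [0, upper] along the colormap
    discr = np.linspace(0, upper, n)
    # calling a colormap with an array of floats returns RGBA rows
    return cmap(discr)
```

With colors = sample_colors(40), each line can then be drawn with color=colors[i]. Raising or lowering upper trades the range of hues against how close the first and last colors are.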

Creating colors from multiple colormaps

Matplotlib offers many sequential colormaps. We can combine several of them to create a larger set of colors, for example:

import numpy as np
from matplotlib import pyplot as plt
import matplotlib.cm as cm

# compile a list of colormaps
colormaps = [cm.Reds_r, cm.Blues_r, cm.Greens_r, cm.Purples_r]
N = 40 # number of lines
x = np.array([0, 1])
theta = np.linspace(0, np.pi / 2, N)

# extract the following number of colors for each colormap
n_cols_per_cm = int(np.ceil(N / len(colormaps)))
# discretize the colormap. Note the upper limit of 0.75, so we
# avoid too white-ish colors
discr = np.linspace(0, 0.75, n_cols_per_cm)

# extract the colors
colors = np.zeros((n_cols_per_cm * len(colormaps), 4))
for i, cmap in enumerate(colormaps):
    colors[i * n_cols_per_cm : (i + 1) * n_cols_per_cm, :] = cmap(discr)

f, ax = plt.subplots()
for i, t in enumerate(theta):
    ax.plot(x, np.tan(t) * x, color=colors[i])
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
plt.show()
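To keep this flexible with respect to the number of variables, the same idea can be wrapped in a function. This is only a sketch: make_palette and its parameters are hypothetical names, and the choice of colormaps is the one used above.

```python
import numpy as np
from matplotlib import pyplot as plt


def make_palette(n, cmap_names=("Reds_r", "Blues_r", "Greens_r", "Purples_r"),
                 upper=0.75):
    """Return an (n, 4) array of RGBA colors drawn from several colormaps."""
    # number of colors to take from each colormap (rounded up)
    n_per_cmap = int(np.ceil(n / len(cmap_names)))
    # stop at `upper` to avoid the near-white end of the reversed maps
    discr = np.linspace(0, upper, n_per_cmap)
    blocks = [plt.get_cmap(name)(discr) for name in cmap_names]
    # stack the per-colormap blocks and drop any surplus rows
    return np.vstack(blocks)[:n]


# Usage with the names from the question (assumes they are already defined):
# colors = make_palette(len(num_columns))
# for i, col in enumerate(num_columns):
#     plot_cumulative_dynamic_auc(X_test.iloc[:, i], col, color=colors[i])
```

Because surplus rows are dropped at the end, the last colormap contributes fewer colors whenever n is not a multiple of the number of colormaps.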
