Two ranges/color scales within one seaborn heatmap

Time: 01-26

I have done quite a bit of research but haven't found a satisfying solution yet.

I'm trying to build a heatmap using seaborn. My data vary a lot in the lower range (0-20) but reach values up to 7000, so a single color scale for all of the data doesn't allow a good visual interpretation. That's why I thought about using two scales with two different color spectra.
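To see why one linear scale struggles here, the following small check is a sketch. It assumes a linear 0-7000 scale and seaborn's "crest" colormap. A Matplotlib colormap is a lookup table of cmap.N colors, so each color has to cover a slice of the data range:

```python
import numpy as np
import matplotlib as mpl
import seaborn as sns  # importing seaborn registers "crest" with matplotlib

vmin, vmax = 0, 7000  # the data range described in the question
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
cmap = mpl.colormaps["crest"]

# Width of the data slice that shares a single color of the lookup table
units_per_color = (vmax - vmin) / cmap.N
print(f"{cmap.N} colors, {units_per_color:.1f} data units per color")

# Collect the distinct colors assigned to the integers 0..20
low_colors = {cmap(norm(v)) for v in np.arange(0, 21)}
print(f"distinct colors for 0-20: {len(low_colors)}")
```

When each color covers more than 20 data units, the whole 0-20 range collapses into one or two colors. That is the loss of detail that motivates using two scales.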

I would like to merge those two heatmaps into one:

[Image: the two separate heatmaps]

What works so far is that both axes (scales) are displayed in my plot, but when it comes to plotting the data, only the last active axis is taken into account, so the upper range is ignored.

[Images: heatmap with two scales; heatmap with two different colormaps]

Answer:

You need to create the two Axes beforehand and then pass the ax argument to sns.heatmap to tell seaborn which Axes gets which colormap.

Example (using mock data):

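import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Mock data spanning the question's range
values = np.random.uniform(0, 7000, size=(50, 50))
vmax = np.amax(values)

# One Axes per value range, side by side
fig, (ax1, ax2) = plt.subplots(ncols=2)

# Each heatmap draws every cell; values outside [vmin, vmax] get the colormap's end colors
sns.heatmap(values, ax=ax1, vmin=25, vmax=vmax, cmap="crest")  # upper range
sns.heatmap(values, ax=ax2, vmin=0, vmax=25, cmap="flare")     # lower range

for ax in (ax1, ax2):
    ax.set_xlabel("Time")
    ax.set_ylabel("Method")

plt.show()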

Resulting figure:

[Image: the resulting figure]
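The two heatmaps above still sit side by side, and each one colors every cell. The following variant is not part of the original answer. It draws both heatmaps on the same Axes and uses seaborn's mask parameter, which leaves cells where the mask is True undrawn, so that each cell is colored by exactly one colormap. The sketch assumes the same split at 25 and uses mock data. The name is_low is purely illustrative:

```python
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

vsplit = 25  # assumed boundary between the two ranges
values = np.random.uniform(0, 7000, size=(50, 50))          # mock data
values[:10] = np.random.uniform(0, vsplit, size=(10, 50))   # put some rows in the low range

fig, ax = plt.subplots()

# True where a cell belongs to the low range
is_low = values < vsplit

# Upper range: hide low cells; lower range: hide high cells
sns.heatmap(values, ax=ax, mask=is_low, vmin=vsplit, vmax=values.max(), cmap="crest")
sns.heatmap(values, ax=ax, mask=~is_low, vmin=0, vmax=vsplit, cmap="flare")

ax.set_xlabel("Time")
ax.set_ylabel("Method")
plt.show()
```

Each sns.heatmap call adds its own colorbar, so the figure ends up with one merged heatmap and two colorbars, one per range.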

Answer:

You can create a custom Colormap class that combines the two colormaps and switches from one to the other at the split value:

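import numpy as np
import pandas as pd
import seaborn as sns  # importing seaborn registers "flare" and "crest" with matplotlib
import matplotlib as mpl
import matplotlib.pyplot as plt

class SplitCMap(mpl.colors.Colormap):
  """Use one colormap below vsplit and another from vsplit up to vmax."""
  def __init__(self, name, vmin, vsplit, vmax, N=256):
    super().__init__(name, N)
    # mpl.cm.get_cmap was removed in Matplotlib 3.9; use the colormap registry instead
    self.lowcmap = mpl.colormaps['flare']
    self.highcmap = mpl.colormaps['crest']
    # Position of vsplit on the normalized 0-1 scale that the colormap receives
    self.split_level = (vsplit-vmin) / (vmax-vmin)
    self.scale_low = 1.0 / self.split_level
    self.scale_high = 1.0 / (1.0 - self.split_level)
  def __call__(self, X, alpha=None, bytes=False):
    # Vectorized: X may be a scalar or an n-D (masked) array of normalized values
    v = np.ma.asarray(X, dtype=float)
    low = self.lowcmap(v * self.scale_low, alpha=alpha, bytes=bytes)
    high = self.highcmap((v-self.split_level)*self.scale_high, alpha=alpha, bytes=bytes)
    # Element-wise, take the color from whichever colormap owns that range
    return np.where((v < self.split_level)[..., None], low, high)

# Mock data standing in for the question's `merged` values and `classes` row labels
merged = np.random.uniform(0, 7000, size=(10, 50))
classes = [f"method {i}" for i in range(merged.shape[0])]
df = pd.DataFrame(merged, classes)

vmin, vsplit, vmax = 0, 25, np.amax(merged)

cmap = SplitCMap('split', vmin, vsplit, vmax)
# Pass the same vmin/vmax so the data are normalized exactly as the colormap expects
ax = sns.heatmap(df, cmap=cmap, vmin=vmin, vmax=vmax)

plt.xlabel("Time")
plt.ylabel("Method")

plt.show()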
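If subclassing Colormap feels heavy, a different technique (not the one used in this answer) gets a similar result from stock Matplotlib pieces. The first piece is a ListedColormap built from half "flare" colors and half "crest" colors. The second is matplotlib.colors.TwoSlopeNorm, which maps vmin..vcenter linearly onto 0-0.5 and vcenter..vmax onto 0.5-1. The sketch assumes a split at 25, uses mock data, and relies on a recent seaborn release that forwards norm to Matplotlib without also setting vmin/vmax. The name "flare_crest" is arbitrary:

```python
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns  # importing seaborn registers "flare" and "crest" with matplotlib

vmin, vsplit = 0, 25                                  # assumed split point
values = np.random.uniform(0, 7000, size=(50, 50))    # mock data
vmax = float(np.amax(values))

# 128 colors from each colormap: entries 0-127 are "flare", 128-255 are "crest"
colors = np.vstack([
    mpl.colormaps["flare"](np.linspace(0, 1, 128)),
    mpl.colormaps["crest"](np.linspace(0, 1, 128)),
])
cmap = mpl.colors.ListedColormap(colors, name="flare_crest")

# [vmin, vsplit] -> [0, 0.5] and [vsplit, vmax] -> [0.5, 1], so each range gets half the colors
norm = mpl.colors.TwoSlopeNorm(vcenter=vsplit, vmin=vmin, vmax=vmax)

ax = sns.heatmap(values, cmap=cmap, norm=norm)
ax.set_xlabel("Time")
ax.set_ylabel("Method")
plt.show()
```

Because the norm gives the 0-25 range half of the 0-1 interval, the low range keeps 128 distinct colors instead of sharing one or two with the rest of the scale.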