I have done quite a bit of research but haven't found a satisfying solution yet.
I'm trying to build a heatmap using seaborn. My dataset fluctuates within a low range (0-20) but also reaches values up to 7000, so a single color scale for all of the data does not allow a good graphical interpretation. That's why I thought about using two scales with two different color spectrums.
I would like to merge those two heatmaps into one:
What works so far is that both axes (scales) are displayed in my plot, but when the data is plotted only the last active axis is taken into account, so the upper range is ignored.
CodePudding user response:
You need to create the two Axes to plot on beforehand, and then use the ax argument when calling sns.heatmap to tell seaborn which Axes should have which colormap.
Example (using mock data):
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
values = np.random.uniform(0, 7000, size=(50, 50))
vmax = np.amax(values)
fig, (ax1, ax2) = plt.subplots(ncols=2)
sns.heatmap(values, ax=ax1, vmin=25, vmax=vmax, cmap="crest")
sns.heatmap(values, ax=ax2, vmin=0, vmax=25, cmap="flare")
for ax in (ax1, ax2):
    ax.set_xlabel("Time")
    ax.set_ylabel("Method")
plt.show()
Resulting figure: (image not available)
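If both ranges should end up in a single plot rather than side by side, one option is to draw two heatmaps on the same Axes and hide the cells that belong to the other range with the mask argument of sns.heatmap. The following is only a sketch on mock data: the split value of 25 and the variable names are assumptions carried over from the example above, and each call adds its own colorbar.

```python
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Mock data (an assumption): mostly large values, plus some columns in the low range.
rng = np.random.default_rng(0)
values = rng.uniform(0, 7000, size=(50, 50))
values[:, :10] = rng.uniform(0, 20, size=(50, 10))
split = 25  # assumed boundary between the two color scales

fig, ax = plt.subplots()
# Cells where mask is True are not drawn, so each call paints only its own range.
sns.heatmap(values, mask=values > split, vmin=0, vmax=split,
            cmap="flare", ax=ax)
sns.heatmap(values, mask=values <= split, vmin=split, vmax=values.max(),
            cmap="crest", ax=ax)
ax.set_xlabel("Time")
ax.set_ylabel("Method")
```

Call plt.show() afterwards to display the figure.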
CodePudding user response:
You can create a custom Colormap class that delegates to the two other colormaps:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
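class SplitCMap(mpl.colors.Colormap):
    def __init__(self, name, vmin, vsplit, vmax, N=256):
        super().__init__(name, N)
        # flare colors the low range, crest the high range
        self.lowcmap = sns.color_palette('flare', as_cmap=True)
        self.highcmap = sns.color_palette('crest', as_cmap=True)
        # Position of the split in normalized [0, 1] coordinates
        self.split_level = (vsplit - vmin) / (vmax - vmin)
        self.scale_low = 1.0 / self.split_level
        self.scale_high = 1.0 / (1.0 - self.split_level)

    def mapcolor(self, v, **kwds):
        # Rescale each half so that it spans its whole colormap
        if v < self.split_level:
            return self.lowcmap(v * self.scale_low, **kwds)
        return self.highcmap((v - self.split_level) * self.scale_high, **kwds)

    def __call__(self, *args, **kwds):
        if isinstance(args[0], (int, float)):
            return self.mapcolor(args[0], **kwds)
        # Map arrays of any shape element by element, then restore the shape
        values = np.asarray(args[0], dtype=float)
        colors = [self.mapcolor(v, **kwds) for v in values.ravel()]
        return np.array(colors).reshape(values.shape + (-1,))
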
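# merged and classes are assumed to be the data and row labels from the question
df = pd.DataFrame(merged, classes)
vmax = np.amax(merged)
cmap = SplitCMap('split', 0, 25, vmax)
# Pass the same limits to the heatmap so that the split lands at 25
ax = sns.heatmap(df, cmap=cmap, vmin=0, vmax=vmax)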
plt.xlabel("Time")
plt.ylabel("Method")
plt.show()
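A variation on the same idea that avoids subclassing Colormap, as a hedged sketch: stack samples of two colormaps into one ListedColormap and let matplotlib's TwoSlopeNorm place the break at 25. The helper make_split_cmap is a hypothetical name, viridis and magma only stand in for flare and crest, and the data is mock data rather than the data from the question.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, TwoSlopeNorm


def make_split_cmap(low_cmap, high_cmap, n=256):
    """Stack two colormaps: low_cmap fills [0, 0.5], high_cmap fills [0.5, 1]."""
    samples = np.linspace(0.0, 1.0, n)
    colors = np.vstack([low_cmap(samples), high_cmap(samples)])
    return ListedColormap(colors, name="split")


# Stand-in colormaps and mock data (assumptions, not the asker's data).
split_cmap = make_split_cmap(plt.get_cmap("viridis"), plt.get_cmap("magma"))
values = np.random.default_rng(0).uniform(0, 7000, size=(50, 50))

# TwoSlopeNorm maps vmin to 0, vcenter to 0.5 and vmax to 1, so the range
# 0-25 uses the whole lower colormap and 25-max the whole upper one.
norm = TwoSlopeNorm(vmin=0, vcenter=25, vmax=values.max())

fig, ax = plt.subplots()
mesh = ax.pcolormesh(values, cmap=split_cmap, norm=norm)
fig.colorbar(mesh, ax=ax)
ax.set_xlabel("Time")
ax.set_ylabel("Method")
```

sns.heatmap forwards extra keyword arguments to pcolormesh, so cmap=split_cmap and norm=norm should be usable there as well, although how norm interacts with vmin and vmax may depend on the seaborn version. Call plt.show() to display the figure.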