How to add group color bars along the xticks and yticks in Seaborn or Matplotlib

Time:09-28

Given an array normal_data and group labels mgroup for its rows and columns, the following heatmap is generated.

[Image: the heatmap, with the desired colored group rectangles along the x and y ticks]

But I wonder whether it is possible, using Matplotlib or Seaborn, to add a color bar along the x ticks and y ticks (the blue, orange, light green, dark blue, and dark green rectangles) that represents group membership, as shown above.

import seaborn as sns
from matplotlib import pyplot as plt
import numpy as np
mgroup=np.array([1,1,1,2,2,3,3,3,3,4,5,5])  # group label of each row/column
normal_data = np.random.randn(12, 12)  # 12x12 random values to plot

ax = sns.heatmap(normal_data, center=0)
plt.show()

Answer:

You can use plt.Rectangle to draw colored rectangles. Setting clip_on=False allows drawing outside the main axes. The colored x and y grouping
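A minimal sketch of the plt.Rectangle idea follows. The helper name add_group_rectangles, the rectangle size, the tick padding, and the colors (Matplotlib's default color cycle) are illustrative assumptions, not taken from the answer. To keep it runnable with Matplotlib alone, ax.pcolormesh stands in for sns.heatmap: seaborn draws cell (i, j) over the same unit square [j, j+1] x [i, i+1] and inverts the y-axis, so the Axes returned by sns.heatmap(normal_data, center=0) should work in its place.

```python
import numpy as np
import matplotlib.pyplot as plt


def add_group_rectangles(ax, groups, size=0.6):
    """Draw a colored rectangle below each column and left of each row.

    Assumes the sns.heatmap layout: cell (i, j) covers [j, j+1] x [i, i+1]
    in data coordinates, with the y-axis inverted so that row 0 is on top.
    """
    groups = np.asarray(groups)
    n = len(groups)
    # 0-based group index for every row/column (sorted by label)
    _, codes = np.unique(groups, return_inverse=True)
    # Matplotlib's default color cycle: C0 blue, C1 orange, C2 green, ...
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    patches = []
    for k, code in enumerate(codes):
        color = colors[code % len(colors)]
        # Below column k: y = n is the bottom edge and larger y is further down
        col_rect = plt.Rectangle((k, n), 1, size, color=color, clip_on=False)
        # Left of row k: spans x from -size to the left edge at x = 0
        row_rect = plt.Rectangle((-size, k), size, 1, color=color, clip_on=False)
        ax.add_patch(col_rect)
        ax.add_patch(row_rect)
        patches.extend([col_rect, row_rect])
    return patches


rng = np.random.default_rng(0)
mgroup = np.array([1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 5, 5])
normal_data = rng.standard_normal((12, 12))
n = len(mgroup)

fig, ax = plt.subplots(figsize=(5, 5))
# Stand-in for ax = sns.heatmap(normal_data, center=0): same cell layout
mesh = ax.pcolormesh(normal_data, cmap="RdBu_r", vmin=-3, vmax=3)
# Fixed limits (as seaborn sets them) so the outside patches don't trigger autoscaling
ax.set_xlim(0, n)
ax.set_ylim(n, 0)  # inverted: row 0 at the top
ax.set_xticks(np.arange(n) + 0.5)
ax.set_yticks(np.arange(n) + 0.5)
ax.set_xticklabels([str(i) for i in range(n)])
ax.set_yticklabels([str(i) for i in range(n)])
# Push tick labels outward past the rectangles (pad is in points; tune by eye)
ax.tick_params(pad=15)
fig.colorbar(mesh, ax=ax)

patches = add_group_rectangles(ax, mgroup)
```

Call plt.show() (or fig.savefig(...)) afterwards to see the result. The rectangles stay outside the plot only because the axes limits are fixed before they are added; with autoscaling on, add_patch would enlarge the limits to include them.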

Answer:

For the group values, I use np.unique() with return_counts=True to get the unique group labels and their frequencies, and use those counts to draw stacked bar charts. Each bar chart is drawn in a new axes added with fig.add_axes, one to the right of the heatmap and one above it. I positioned the labels by hand, so this may not be the best approach, and some of the code may be redundant.
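The label positions do not have to be picked by hand. Assuming mgroup is sorted so that each group occupies one contiguous block of rows/columns (true for this example), the counts returned by np.unique give each group's start, end, and midpoint in heatmap cell units. A minimal sketch:

```python
import numpy as np

mgroup = np.array([1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 5, 5])

# Unique labels and the number of rows/columns each group spans
u, counts = np.unique(mgroup, return_counts=True)
# Assuming contiguous groups: each group ends where the running total reaches it
ends = np.cumsum(counts)
starts = ends - counts
# Middle of each group's block, in heatmap cell units (cell k spans [k, k+1])
centers = starts + counts / 2
labels = [f"Group{g}" for g in u]

for label, lo, hi, c in zip(labels, starts, ends, centers):
    print(f"{label}: spans [{lo}, {hi}), center {c}")
```

Because sns.heatmap places cell k over [k, k+1], these centers can be used directly as tick positions, for example ax.set_yticks(centers) followed by ax.set_yticklabels(labels), instead of guessing fig.transFigure coordinates.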

import seaborn as sns
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

mgroup = np.array([1,1,1,2,2,3,3,3,3,4,5,5])
normal_data = np.random.randn(12, 12)
# Unique group labels and the number of rows/columns in each group
u, counts = np.unique(mgroup, return_counts=True)
# One row per group; df.T turns this into a single row whose columns become stacked segments
df = pd.DataFrame({'value': counts}, index=u)

fig, axs = plt.subplots(1, 1, figsize=(4,4))
# Horizontal colorbar in its own axes along the bottom of the figure
cbar_ax = fig.add_axes([0.1,0.0,0.8,0.05])
g = sns.heatmap(normal_data, center=0, cbar_ax=cbar_ax, cbar_kws={"orientation": "horizontal"}, ax=axs)

# Right-hand group bar: a single vertical bar stacked by group size
df_ax = fig.add_axes([0.9,0.12,0.05,0.80])
d = df.T.plot(kind='bar', stacked=True, width=0.7, legend=False, ax=df_ax)
# Stacked bars start with group 1 at the bottom, but heatmap rows run top to bottom,
# so flip the y-axis to put group 1 at the top
df_ax.invert_yaxis()
for spine in d.spines.values():
    spine.set_visible(False)
df_ax.set_xticks([])
df_ax.set_yticks([])

# Top group bar: a single horizontal bar stacked left to right, like the columns
df_ax2 = fig.add_axes([0.12,0.85,0.82,0.05])
d = df.T.plot(kind='barh', stacked=True, width=0.7, legend=False, ax=df_ax2)
for spine in d.spines.values():
    spine.set_visible(False)
df_ax2.set_xticks([])
df_ax2.set_yticks([])

# Group labels, placed by hand in figure coordinates
df_ax.text(1.0, 0.20, 'Group5', transform=fig.transFigure)
df_ax.text(1.0, 0.35, 'Group4', transform=fig.transFigure)
df_ax.text(1.0, 0.55, 'Group3', transform=fig.transFigure)
df_ax.text(1.0, 0.70, 'Group2', transform=fig.transFigure)
df_ax.text(1.0, 0.80, 'Group1', transform=fig.transFigure)
plt.show()

[Image: resulting heatmap with stacked group bars on the top and right]
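The fig.add_axes rectangles in the code above are tuned for a 4x4-inch figure and drift out of alignment when the size or layout changes. One alternative, not the answerer's method, is mpl_toolkits.axes_grid1.make_axes_locatable, which appends new axes flush with the heatmap's edges. The sketch below assumes contiguous groups, uses ax.pcolormesh as a stand-in for sns.heatmap(normal_data, center=0, cbar=False, ax=ax) (same cell layout), and draws each group strip as a one-cell-thick color grid with a ListedColormap built from the default color cycle; these are illustrative choices.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from mpl_toolkits.axes_grid1 import make_axes_locatable

rng = np.random.default_rng(0)
mgroup = np.array([1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 5, 5])  # assumed contiguous groups
normal_data = rng.standard_normal((12, 12))
n = len(mgroup)

# Group labels, 0-based group index per row/column, and group sizes
u, codes, counts = np.unique(mgroup, return_inverse=True, return_counts=True)
centers = np.cumsum(counts) - counts / 2  # middle of each group's block of cells

# One color per group, taken from the default color cycle
cycle = plt.rcParams["axes.prop_cycle"].by_key()["color"]
group_cmap = ListedColormap([cycle[i % len(cycle)] for i in range(len(u))])
strip_kw = dict(cmap=group_cmap, vmin=-0.5, vmax=len(u) - 0.5)

fig, ax = plt.subplots(figsize=(5, 5))
# Stand-in for sns.heatmap(normal_data, center=0, cbar=False, ax=ax): same cell layout
mesh = ax.pcolormesh(normal_data, cmap="RdBu_r", vmin=-3, vmax=3)
ax.set_xlim(0, n)
ax.set_ylim(n, 0)  # row 0 at the top

# New axes attached to the heatmap's edges, so nothing is positioned by hand
divider = make_axes_locatable(ax)
top = divider.append_axes("top", size="5%", pad=0.05)
right = divider.append_axes("right", size="5%", pad=0.05)
cax = divider.append_axes("bottom", size="5%", pad=0.5)

# Top strip: one cell per column, colored by that column's group
top.pcolormesh(codes[np.newaxis, :], **strip_kw)
top.set_xlim(0, n)
top.set_xticks([])
top.set_yticks([])

# Right strip: one cell per row, in the same top-to-bottom order as the heatmap
right.pcolormesh(codes[:, np.newaxis], **strip_kw)
right.set_ylim(n, 0)
right.set_xticks([])
right.yaxis.tick_right()
right.set_yticks(centers)
right.set_yticklabels([f"Group{g}" for g in u])

fig.colorbar(mesh, cax=cax, orientation="horizontal")
```

With seaborn, passing cbar=False to sns.heatmap keeps it from adding its own colorbar next to the one placed by the divider. Call plt.show() to display the figure.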
