Given an array normal_data with a group-label array mgroup, the following heatmap is generated.
Is it possible, using Matplotlib or Seaborn, to add color bars along the x- and y-tick axes (the blue, orange, light green, dark blue, and dark green rectangles) that represent group membership, as shown above?
import seaborn as sns
from matplotlib import pyplot as plt
import numpy as np
mgroup=np.array([1,1,1,2,2,3,3,3,3,4,5,5])
normal_data = np.random.randn(12, 12)
ax = sns.heatmap(normal_data, center=0)
plt.show()
Answer:
You can use plt.Rectangle to draw colored rectangles. Setting clip_on=False allows them to be drawn outside the main axes.
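A minimal sketch of the plt.Rectangle approach. It assumes, as in the question, that the heatmap is square, that its rows and columns are both ordered by mgroup, and that each group is one contiguous run. It relies on seaborn's heatmap drawing cell (i, j) over x in [j, j+1] and y in [i, i+1] with the y axis inverted, so row 0 is at the top. add_group_bars is a hypothetical helper name, and the tab10 colormap and bar width are arbitrary choices.

```python
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


def add_group_bars(ax, groups, width=0.5, cmap_name="tab10"):
    """Hypothetical helper: draw colored group bars left of the rows and
    above the columns of a square seaborn heatmap ordered by `groups`.

    `width` is the bar thickness in cell units.
    """
    groups = np.asarray(groups)
    uniq = np.unique(groups)
    cmap = plt.get_cmap(cmap_name)
    # Start index of every run of equal consecutive labels, plus the final end.
    starts = np.flatnonzero(np.r_[True, groups[1:] != groups[:-1]])
    edges = np.r_[starts, len(groups)]
    patches = []
    for lo, hi in zip(edges[:-1], edges[1:]):
        # Color by label (not by run index), so a label always gets one color.
        color = cmap(int(np.searchsorted(uniq, groups[lo])) % cmap.N)
        # Row bar: just left of column 0 (negative x), covering rows lo..hi.
        row_bar = plt.Rectangle((-width, lo), width, hi - lo,
                                color=color, clip_on=False)
        # Column bar: just above row 0 (negative y, because y is inverted).
        col_bar = plt.Rectangle((lo, -width), hi - lo, width,
                                color=color, clip_on=False)
        for bar in (row_bar, col_bar):
            ax.add_patch(bar)
            patches.append(bar)
    return patches


mgroup = np.array([1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 5, 5])
normal_data = np.random.randn(12, 12)
# Hide the row tick labels so they do not overlap the row bars.
ax = sns.heatmap(normal_data, center=0, yticklabels=False)
patches = add_group_bars(ax, mgroup)
plt.show()
```

Because seaborn fixes the axis limits, adding the patches does not rescale the heatmap; the bars simply sit in the margin. Increase width, or the figure margins, if the bars look too thin or get cut off.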
Answer:
For the group values, I use np.unique() with return_counts=True to get the unique group labels and the number of rows (or columns) in each group, and use those counts to build stacked bar charts. The bar charts are drawn in new axes added to the right of and above the heatmap. I positioned the labels by hand, so this may not be the best approach, and some of the code may be redundant.
import seaborn as sns
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
mgroup=np.array([1,1,1,2,2,3,3,3,3,4,5,5])
normal_data = np.random.randn(12, 12)
u,counts = np.unique(mgroup, return_counts=True)
df = pd.DataFrame({'value':counts},index=u)
fig, axs = plt.subplots(figsize=(4, 4))
cbar_ax = fig.add_axes([0.1,0.0,0.8,0.05])
g = sns.heatmap(normal_data, center=0, cbar_ax=cbar_ax, cbar_kws={"orientation": "horizontal"}, ax=axs)
#g.set_aspect('auto')
df_ax = fig.add_axes([0.9,0.12,0.05,0.80])
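d = df.T.plot(kind='bar', stacked=True, width=0.7, legend=False, ax=df_ax)
df_ax.invert_yaxis()  # stacked bars grow upward from group 1; flip so group 1 is at the top, like heatmap row 0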
for side in ('top', 'bottom', 'right', 'left'):
    d.spines[side].set_visible(False)
df_ax.set_xticks([])
df_ax.set_yticks([])
df_ax2 = fig.add_axes([0.12,0.85,0.82,0.05])
d = df.T.plot(kind='barh', stacked=True, width=0.7, legend=False, ax=df_ax2)
for side in ('top', 'bottom', 'right', 'left'):
    d.spines[side].set_visible(False)
df_ax2.set_xticks([])
df_ax2.set_yticks([])
df_ax.text(1.0, 0.20, 'Group5', transform=fig.transFigure)
df_ax.text(1.0, 0.35, 'Group4', transform=fig.transFigure)
df_ax.text(1.0, 0.55, 'Group3', transform=fig.transFigure)
df_ax.text(1.0, 0.70, 'Group2', transform=fig.transFigure)
df_ax.text(1.0, 0.80, 'Group1', transform=fig.transFigure)
plt.show()
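The five hard-coded df_ax.text positions can instead be derived from the group sizes. This is a minimal sketch, assuming (as the answer already does) that mgroup is sorted so each group forms one contiguous block. group_label_positions and label_groups are hypothetical helper names. Placing the text with ax.get_yaxis_transform() puts x in axes coordinates and y in data coordinates, so each label lands at the middle of its stacked-bar segment regardless of the figure layout.

```python
import numpy as np


def group_label_positions(groups):
    """Return the sorted group labels and the center of each group's block,
    measured in cells from the start. Assumes each group is contiguous."""
    labels, counts = np.unique(groups, return_counts=True)
    ends = np.cumsum(counts)         # end index of each block
    centers = ends - counts / 2      # midpoint of [end - count, end]
    return labels, centers


def label_groups(ax, groups, x=1.1):
    """Write 'Group<label>' beside each stacked-bar segment in `ax`.
    x is in axes coordinates; y is in data coordinates."""
    labels, centers = group_label_positions(groups)
    return [
        ax.text(x, c, f"Group{lab}", transform=ax.get_yaxis_transform(),
                va="center")
        for lab, c in zip(labels, centers)
    ]


mgroup = np.array([1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 5, 5])
labels, centers = group_label_positions(mgroup)
print(labels, centers)  # centers: 1.5, 4, 7, 9.5, 11
```

In the answer's code, label_groups(df_ax, mgroup) would replace the five df_ax.text(...) calls. Because y is in data coordinates, each label follows its bar segment whether or not the bar axis is inverted.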