In Python, I have a seaborn (sns) heatmap with different groups of rows, where each group is colored differently. I would like to add a label for each color (each group of rows), preferably on the right side of the heatmap. For example, for the red group I want a label called "Increasing Trend" to the right of its rows, for the purple group a label called "Emerging Trend", and so on. Below is my heatmap:
Below is how I generated the heatmap with different colors:
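import pandas as pd
import seaborn as sns
from matplotlib import pyplot

# agg (values), a_list (column names) and agg_x (index) come from my own data
agg_df=pd.DataFrame(agg, columns=a_list, index=agg_x)
print(agg_df)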
agg_df_1=agg_df.copy()
agg_df_2=agg_df.copy()
agg_df_3=agg_df.copy()
agg_df_4=agg_df.copy()
#generate heatmap
agg_df_1.iloc[:,24:] = float('nan')
g=sns.heatmap(agg_df_1.transpose(), annot=False,fmt=".1f",annot_kws={"fontsize":6},cmap=pyplot.cm.Reds,xticklabels=1, yticklabels=1,cbar=False)
g.set_xticklabels(g.get_xticklabels(), fontsize = 8)
g.set_yticklabels(g.get_yticklabels(), fontsize = 8)
agg_df_2.iloc[:,:24]=float('nan')
agg_df_2.iloc[:,35:]=float('nan')
g=sns.heatmap(agg_df_2.transpose(), annot=False,fmt=".1f",annot_kws={"fontsize":6},cmap=pyplot.cm.Oranges,xticklabels=1, yticklabels=1,cbar=False)
g.set_xticklabels(g.get_xticklabels(), fontsize = 8)
g.set_yticklabels(g.get_yticklabels(), fontsize = 8)
agg_df_3.iloc[:,:35]=float('nan')
agg_df_3.iloc[:,44:]=float('nan')
g=sns.heatmap(agg_df_3.transpose(), annot=False,fmt=".1f",annot_kws={"fontsize":6},cmap=pyplot.cm.Purples,xticklabels=1, yticklabels=1,cbar=False)
g.set_xticklabels(g.get_xticklabels(), fontsize = 8)
g.set_yticklabels(g.get_yticklabels(), fontsize = 8)
agg_df_4.iloc[:,:44]=float('nan')
g=sns.heatmap(agg_df_4.transpose(), annot=False,fmt=".1f",annot_kws={"fontsize":6},cmap=pyplot.cm.Greens,xticklabels=1, yticklabels=1,cbar=False)
g.set_xticklabels(g.get_xticklabels(), fontsize = 8)
g.set_yticklabels(g.get_yticklabels(), fontsize = 8)
for label in g.get_yticklabels():
    label.set_weight('bold')
for label in g.get_xticklabels():
    label.set_weight('bold')
#g.set(ylabel='Decreasing Emerging Mix Increasing')
pyplot.show()
UPDATE: Alternatively, a legend for each color would also be a great idea. How can I achieve that?
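The four copy-and-blank blocks in the question all do the same thing: keep one band of columns and set every other column to NaN, and seaborn leaves NaN cells undrawn. A minimal sketch of that step as a loop, assuming the band boundaries 24, 35 and 44 from the question's code; the helper name split_into_bands and the random demo data are hypothetical:

```python
import numpy as np
import pandas as pd


def split_into_bands(df, stops):
    """Return one copy of df per band; columns outside the band are set to NaN.

    Hypothetical helper: stops lists where each band of columns starts and ends,
    e.g. [0, 24, 35, 44, len(df.columns)].
    """
    bands = []
    for beg, end in zip(stops[:-1], stops[1:]):
        band = df.copy()
        band.iloc[:, :beg] = np.nan  # blank every column before the band
        band.iloc[:, end:] = np.nan  # blank every column after the band
        bands.append(band)
    return bands


# demo data standing in for agg_df (46 columns, like the dummy data in the answer)
demo = pd.DataFrame(np.random.rand(5, 46))
bands = split_into_bands(demo, [0, 24, 35, 44, demo.shape[1]])
```

Each copy could then be drawn on the same axes in a loop, e.g. sns.heatmap(band.transpose(), cmap=cmap, cbar=False, ax=ax) for band, cmap in zip(bands, ['Reds', 'Oranges', 'Purples', 'Greens']).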
Answer:
You could simplify the code with a loop, and place the labels with the y-axis transform, in which x is measured in axes coordinates and y in data coordinates:
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
# create some dummy data similar to the given
columns = 'H He Li Be B C N O F Ne Na Mg Al Si P S Cl Ar K Ca Sc Ti V Cr Mn Fe Co Ni Cu Zn Ga Ge As Se Br Kr Rb Sr Y Zr Nb Mo Tc Ru Rh Pd'.split()
idx = [f'Q{i}_{y}' for i in range(1, 5) for y in range(2011, 2020)]
df = pd.DataFrame(np.random.rand(len(idx), len(columns)), columns=columns, index=idx)
# list of places where the bands start and stop
stops = [0, 24, 35, 44, len(columns)]
# color maps for each band
cmaps = ['Reds', 'copper', 'Purples', 'Greens']
# labels for the bands
labels = ['Increasing Trend', 'Orange Trend', 'Emerging Trend', 'Green Trend']
sns.set_context('notebook', font_scale=0.7)  # scale all fonts; font_scale is only applied together with a context name
fig, ax = plt.subplots(figsize=(12, 10))
# loop over the start and end of each band, together with its colormap and label
for beg, end, cmap, label in zip(stops[:-1], stops[1:], cmaps, labels):
    # mask: 1 hides a cell, 0 shows it, so only this band's rows of df.T are drawn
    mask = np.repeat([[1, 0, 1]] * len(idx), [beg, end - beg, len(columns) - end], axis=1).T
    # heatmap for this band
    sns.heatmap(df.T, mask=mask, cmap=cmap, annot=False, cbar=False, ax=ax)
    # add the label (one word per line) just right of the axes, at the band's vertical center
    ax.text(1.01, (beg + end) / 2, '\n'.join(label.split()), ha='left', va='center', transform=ax.get_yaxis_transform())
plt.tight_layout() # fit all text nicely into the plot
plt.show()
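To make the np.repeat call easier to follow, here is a sketch that builds the same mask with plain slicing and checks that both agree, on a tiny hypothetical grid of 3 index entries and 6 columns with made-up stops. The function names are illustrative only. Per the seaborn docs, cells where the mask is True are not shown, so only rows beg to end - 1 of df.T stay visible:

```python
import numpy as np


def band_mask_repeat(n_idx, n_cols, beg, end):
    # the answer's construction: repeat the pattern [1, 0, 1] column-wise,
    # then transpose so each row corresponds to one row of df.T
    return np.repeat([[1, 0, 1]] * n_idx, [beg, end - beg, n_cols - end], axis=1).T


def band_mask_slice(n_idx, n_cols, beg, end):
    # equivalent boolean mask: hide everything, then unhide rows beg..end-1
    mask = np.ones((n_cols, n_idx), dtype=bool)
    mask[beg:end, :] = False
    return mask


n_idx, n_cols = 3, 6
stops = [0, 2, 5, n_cols]
for beg, end in zip(stops[:-1], stops[1:]):
    a = band_mask_repeat(n_idx, n_cols, beg, end)
    b = band_mask_slice(n_idx, n_cols, beg, end)
    print(beg, end, a.shape, np.array_equal(a.astype(bool), b))
```

The slicing version is arguably easier to read, and either one should work as the mask argument of sns.heatmap.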
Answer:
For an additional legend, I managed to achieve a good result with this (g is the Axes returned by sns.heatmap in the question's code):
from matplotlib.lines import Line2D
# one thick colored line per group, used only as a legend handle
custom_lines = [Line2D([0], [0], color='red', lw=4),
                Line2D([0], [0], color='orange', lw=4),
                Line2D([0], [0], color='purple', lw=4),
                Line2D([0], [0], color='green', lw=4)]
# anchor the legend's upper-left corner just outside the right edge of the axes
g.legend(custom_lines, ['Increasing', 'Mix', 'Emerging', 'Decreasing'], loc='upper left', bbox_to_anchor=(1, 1.01))
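A possible variant of this legend, as a sketch: instead of hard-coding 'red', 'orange' and so on, sample one shade from each band's colormap so the legend swatches match the cells (this matters if a band uses a colormap such as 'copper' whose colors are not plain orange). It uses matplotlib.patches.Patch as the legend handle; the dummy data, the 0.7 sampling point and the slicing-based mask are assumptions, not part of the answers above:

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.patches import Patch

# dummy data: 46 columns split into the same four bands as in the question
df = pd.DataFrame(np.random.rand(36, 46))
stops = [0, 24, 35, 44, df.shape[1]]
cmaps = ['Reds', 'Oranges', 'Purples', 'Greens']
labels = ['Increasing', 'Mix', 'Emerging', 'Decreasing']

fig, ax = plt.subplots(figsize=(12, 10))
handles = []
for beg, end, cmap, label in zip(stops[:-1], stops[1:], cmaps, labels):
    # hide every row of df.T except this band
    mask = np.ones(df.T.shape, dtype=bool)
    mask[beg:end, :] = False
    sns.heatmap(df.T, mask=mask, cmap=cmap, cbar=False, ax=ax)
    # a fairly dark shade from the same colormap (0.7 is an arbitrary choice)
    color = plt.get_cmap(cmap)(0.7)
    handles.append(Patch(facecolor=color, label=label))

# anchor the legend's upper-left corner just outside the right edge of the axes
legend = ax.legend(handles=handles, loc='upper left', bbox_to_anchor=(1, 1.01))
plt.tight_layout()
```

Call plt.show() afterwards to display the figure. Because each swatch comes from the colormap itself, changing an entry in cmaps keeps the legend in sync without editing any color names.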