How to add multiple labels for multiple groups of rows in sns heatmap on right side of Y axis?

Time:08-17

In Python, I have a seaborn (sns) heatmap with different groups of rows, where each group is colored differently. I would like to add a label for each color (each group of rows), preferably on the right side of the Y axis. For example, for the red group I want a label called "Increasing Trend" to the right of its rows; for the purple group, a label called "Emerging Trend"; and so on. Below is my heatmap:

[Image: the current heatmap with red, orange, purple and green groups of rows]

Below is how I generated the heatmap with different colors:

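import pandas as pd
import seaborn as sns
from matplotlib import pyplot

# agg, a_list and agg_x come from earlier processing (not shown)
agg_df = pd.DataFrame(agg, columns=a_list, index=agg_x)
print(agg_df)
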
agg_df_1=agg_df.copy()
agg_df_2=agg_df.copy()
agg_df_3=agg_df.copy()
agg_df_4=agg_df.copy()

#generate heatmap
agg_df_1.iloc[:,24:] = float('nan')
g=sns.heatmap(agg_df_1.transpose(), annot=False,fmt=".1f",annot_kws={"fontsize":6},cmap=pyplot.cm.Reds,xticklabels=1, yticklabels=1,cbar=False)
g.set_xticklabels(g.get_xticklabels(), fontsize = 8)
g.set_yticklabels(g.get_yticklabels(), fontsize = 8)

agg_df_2.iloc[:,:24]=float('nan')
agg_df_2.iloc[:,35:]=float('nan')
g=sns.heatmap(agg_df_2.transpose(), annot=False,fmt=".1f",annot_kws={"fontsize":6},cmap=pyplot.cm.Oranges,xticklabels=1, yticklabels=1,cbar=False)
g.set_xticklabels(g.get_xticklabels(), fontsize = 8)
g.set_yticklabels(g.get_yticklabels(), fontsize = 8)


agg_df_3.iloc[:,:35]=float('nan')
agg_df_3.iloc[:,44:]=float('nan')
g=sns.heatmap(agg_df_3.transpose(), annot=False,fmt=".1f",annot_kws={"fontsize":6},cmap=pyplot.cm.Purples,xticklabels=1, yticklabels=1,cbar=False)
g.set_xticklabels(g.get_xticklabels(), fontsize = 8)
g.set_yticklabels(g.get_yticklabels(), fontsize = 8)

agg_df_4.iloc[:,:44]=float('nan')
g=sns.heatmap(agg_df_4.transpose(), annot=False,fmt=".1f",annot_kws={"fontsize":6},cmap=pyplot.cm.Greens,xticklabels=1, yticklabels=1,cbar=False)
g.set_xticklabels(g.get_xticklabels(), fontsize = 8)
g.set_yticklabels(g.get_yticklabels(), fontsize = 8)

for label in g.get_yticklabels():
  label.set_weight('bold')

for label in g.get_xticklabels():
  label.set_weight('bold')

#g.set(ylabel='Decreasing   Emerging      Mix           Increasing')


pyplot.show()

UPDATE: Alternatively, a legend for each color would also be a great idea. How can I achieve that?

Answer:

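You could simplify the code with a loop, and place each label using the y-axis transform (ax.get_yaxis_transform()), in which x is given in axes coordinates and y in data coordinates: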

from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

# create some dummy data similar to the given
columns = 'H He Li Be B C N O F Ne Na Mg Al Si P S Cl Ar K Ca Sc Ti V Cr Mn Fe Co Ni Cu Zn Ga Ge As Se Br Kr Rb Sr Y Zr Nb Mo Tc Ru Rh Pd'.split()
idx = [f'Q{i}_{y}' for i in range(1, 5) for y in range(2011, 2020)]
df = pd.DataFrame(np.random.rand(len(idx), len(columns)), columns=columns, index=idx)

# list of places where the bands start and stop
stops = [0, 24, 35, 44, len(columns)] 
# color maps for each band
cmaps = ['Reds', 'copper', 'Purples', 'Greens']
# labels for the bands
labels = ['Increasing Trend', 'Orange Trend', 'Emerging Trend', 'Green Trend']

# scale all fonts; font_scale only takes effect together with a named context such as 'notebook'
sns.set_context('notebook', font_scale=0.7)
fig, ax = plt.subplots(figsize=(12, 10))
# create a loop using the begin and end of each band, the colors and the labels
for beg, end, cmap, label in zip(stops[:-1], stops[1:], cmaps, labels):
    # mask setting 0 for the band to be plotted in this step
    mask = np.repeat([[1, 0, 1]] * len(idx), [beg, end - beg, len(columns) - end], axis=1).T
    # heatmap for this band
    sns.heatmap(df.T, mask=mask, cmap=cmap, annot=False, cbar=False, ax=ax)
    # add some text to the center right of this band
    ax.text(1.01, (beg + end) / 2, '\n'.join(label.split()), ha='left', va='center', transform=ax.get_yaxis_transform())
plt.tight_layout()  # fit all text nicely into the plot
plt.show()

[Image: 4 combined heatmaps]
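The loop relies on two small computations: a mask that hides every heatmap row outside the current band, and the band's vertical midpoint (beg + end) / 2 in data coordinates, since row i of a seaborn heatmap spans y = i to i + 1. Below is a minimal sketch of both, assuming the same 46 x 36 layout as the dummy data. band_mask and band_center are hypothetical helper names, and the mask here is built by slicing a boolean array rather than with np.repeat; the assert inside the loop checks that both constructions give the same array:

```python
import numpy as np

def band_mask(n_rows, n_cols, beg, end):
    """Hypothetical helper: boolean mask that hides (True) every row outside beg..end-1."""
    mask = np.ones((n_rows, n_cols), dtype=bool)
    mask[beg:end, :] = False  # only the current band stays visible
    return mask

def band_center(beg, end):
    """Hypothetical helper: row i spans y = i to i + 1, so the band's middle is (beg + end) / 2."""
    return (beg + end) / 2

# Same layout as the dummy data: 46 heatmap rows (df.T) in 4 bands, 36 heatmap columns.
stops = [0, 24, 35, 44, 46]
n_rows, n_cols = stops[-1], 36

for beg, end in zip(stops[:-1], stops[1:]):
    mask = band_mask(n_rows, n_cols, beg, end)
    # The answer's np.repeat construction gives the same array, as 0/1 instead of bool.
    ref = np.repeat([[1, 0, 1]] * n_cols, [beg, end - beg, n_rows - end], axis=1).T
    assert np.array_equal(ref.astype(bool), mask)
    print(f"rows {beg}-{end - 1}: label at y = {band_center(beg, end)}")
```

Passing band_mask(...) as mask= to sns.heatmap(df.T, ...) draws one band per call; seaborn converts the mask to booleans and leaves the True cells blank.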

Answer:

For an additional legend, I managed to achieve a good result with this:

from matplotlib.lines import Line2D
custom_lines = [Line2D([0], [0], color='red', lw=4),
                Line2D([0], [0], color='orange', lw=4),
                Line2D([0], [0], color='purple', lw=4),
                Line2D([0], [0], color='green', lw=4)]

# g is the Axes returned by the last sns.heatmap call in the question's code
g.legend(custom_lines, ['Increasing', 'Mix', 'Emerging', 'Decreasing'], loc='upper left', bbox_to_anchor=(1, 1.01))

[Image: the heatmap with the added legend]
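The Line2D colors above are fixed color names, so they only approximate the shades of the bands. A minimal sketch of an alternative, assuming random dummy data and the same band boundaries as the first answer: it samples each colormap at 0.7 (an arbitrary, fairly dark point) with plt.get_cmap and uses matplotlib.patches.Patch handles, which draw filled boxes instead of lines:

```python
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib.patches import Patch

# Dummy data (assumption): 36 rows x 46 columns, so df.T has 46 heatmap rows.
rng = np.random.default_rng(0)
df = pd.DataFrame(rng.random((36, 46)))
stops = [0, 24, 35, 44, 46]
cmaps = ['Reds', 'Oranges', 'Purples', 'Greens']
labels = ['Increasing', 'Mix', 'Emerging', 'Decreasing']

fig, ax = plt.subplots(figsize=(12, 10))
handles = []
for beg, end, cmap, label in zip(stops[:-1], stops[1:], cmaps, labels):
    mask = np.ones(df.T.shape, dtype=bool)
    mask[beg:end, :] = False  # show only this band in this pass
    sns.heatmap(df.T, mask=mask, cmap=cmap, cbar=False, ax=ax)
    # Legend patch colored by sampling the same colormap at 0.7.
    handles.append(Patch(facecolor=plt.get_cmap(cmap)(0.7), label=label))

# Put the legend's upper-left corner just outside the right edge of the axes.
ax.legend(handles=handles, loc='upper left', bbox_to_anchor=(1, 1.01))
plt.tight_layout()
plt.show()
```

Because each patch color comes from the colormap that drew its band, changing a colormap in cmaps keeps the legend in sync without editing any color names.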
