I'm trying to visualize different groups within my data on the y-axis of a heatmap (using sns.heatmap). I want the ticks to be categorized according to a specific dictionary, with names and colors, and eventually to be shown in a legend.
I know this can be achieved with sns.clustermap, but that function also clusters the values in the heatmap, which I don't want.
Any idea how I can make such a visualization with a plain heatmap?
Example
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import random
import seaborn as sns
data = pd.DataFrame(np.random.randn(10, 3), columns=(list('ABC')))
data['group'] = pd.Series(random.choices(['group_1', 'group_2', 'group_3'], weights=[1,1,1], k=len(data)))
data
A B C group
0 0.366822 0.583965 1.629740 group_3
1 0.557286 0.450663 0.255852 group_3
2 -0.265515 -0.153028 0.670448 group_3
3 0.132278 -0.226668 1.365583 group_3
4 0.595304 -0.577290 0.395477 group_2
5 -0.805420 0.168376 0.748649 group_1
6 0.105664 -0.568047 -0.281488 group_2
7 -0.046202 0.173409 -0.250321 group_1
8 -0.132696 -0.877354 0.086954 group_3
9 -0.843666 0.655146 -1.629453 group_2
lut = {'group_1': 'red', 'group_2': 'blue', 'group_3': 'green'}
row_colors = data['group'].map(lut)
data.drop(['group'], axis=1, inplace=True)
ax = sns.heatmap(data)  # sns.heatmap returns a matplotlib Axes, not a Figure
I aim to use row_colors to visually indicate, in the resulting heatmap, which rows belong to which group.
Any help would be highly appreciated, either directly from seaborn or some other way :)
Answer:
You might try a clustermap without clusters:
g = sns.clustermap(data=data.drop(['group'], axis=1), row_colors=row_colors, row_cluster=False, col_cluster=False, dendrogram_ratio=0.05, cbar_pos=None)
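Since the question also asks for a legend, here is a self-contained sketch of the clustermap route that adds one. It assumes seaborn 0.10 or later (for dendrogram_ratio and cbar_pos=None). The legend is built from matplotlib.patches.Patch proxy handles taken from the same lut dictionary and placed in the column-dendrogram axes, which stays empty when col_cluster=False. The seeded random data and the legend placement are illustrative choices, not part of the original answer.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.patches import Patch

# Example data shaped like the question's (seeded so the run is reproducible)
rng = np.random.default_rng(0)
data = pd.DataFrame(rng.standard_normal((10, 3)), columns=list('ABC'))
data['group'] = rng.choice(['group_1', 'group_2', 'group_3'], size=len(data))

lut = {'group_1': 'red', 'group_2': 'blue', 'group_3': 'green'}
row_colors = data['group'].map(lut)  # one color string per row, aligned on the index

# With both clusterings switched off, rows and columns keep their original order,
# so the result is a plain heatmap plus a color strip next to the rows
g = sns.clustermap(data=data.drop(columns='group'), row_colors=row_colors,
                   row_cluster=False, col_cluster=False,
                   dendrogram_ratio=0.05, cbar_pos=None)

# Proxy artists: one colored Patch per group, in the order of the lut dictionary
handles = [Patch(facecolor=color, label=name) for name, color in lut.items()]

# The column-dendrogram axes is unused here, so it can host a one-row legend
g.ax_col_dendrogram.legend(handles=handles, loc='center', ncol=len(lut), frameon=False)

plt.show()
```

If the legend does not fit in the thin dendrogram strip, increasing dendrogram_ratio gives it more room.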
Or you could explicitly draw small rectangles next to the heatmap.
In the code below, the drop isn't executed inplace, so that the group column can still be used afterwards.
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
group_names = ['group_1', 'group_2', 'group_3']
data = pd.DataFrame(np.random.randn(10, 3), columns=(list('ABC')))
data['group'] = pd.Series(np.random.choice(group_names, p=[1/3, 1/3, 1/3], size=len(data)))
lut = {'group_1': 'red', 'group_2': 'blue', 'group_3': 'green'}
row_colors = data['group'].map(lut)
ax = sns.heatmap(data.drop(['group'], axis=1))
ax.tick_params(axis='y', which='major', pad=20, length=0) # extra padding to leave room for the row colors
ax.set_yticklabels(data['group'], rotation=0) # optionally use the groups as the tick labels
for i, color in enumerate(row_colors):
    # x is in axes coordinates (just left of the axes), y is in data coordinates (row i)
    ax.add_patch(plt.Rectangle(xy=(-0.05, i), width=0.05, height=1, color=color, lw=0,
                               transform=ax.get_yaxis_transform(), clip_on=False))
plt.tight_layout()
plt.show()
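The same rectangle approach can be wrapped in a small reusable function that also adds the legend the question mentions. This is a sketch under assumptions: heatmap_with_row_groups is a hypothetical helper name, strip_width and the legend position are arbitrary defaults, and the data are seeded random values rather than the question's. As in the answer, ax.get_yaxis_transform() is a blended transform (x in axes coordinates, y in data coordinates), so row i of the heatmap spans y = i to i + 1.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.patches import Patch, Rectangle


def heatmap_with_row_groups(values, groups, lut, strip_width=0.05, ax=None):
    """Hypothetical helper: plain sns.heatmap (row order kept, no clustering)
    with a colored strip left of the rows and a legend of the groups."""
    ax = sns.heatmap(values, ax=ax)
    # Extra padding between the y tick labels and the axes leaves room for the strip
    ax.tick_params(axis='y', which='major', pad=20, length=0)
    trans = ax.get_yaxis_transform()  # x: axes fraction, y: data (row) coordinates
    for i, group in enumerate(groups):
        ax.add_patch(Rectangle(xy=(-strip_width, i), width=strip_width, height=1,
                               color=lut[group], lw=0, transform=trans, clip_on=False))
    # Proxy Patch handles, one per group, placed in a single row above the heatmap
    handles = [Patch(facecolor=color, label=name) for name, color in lut.items()]
    ax.legend(handles=handles, loc='lower center', bbox_to_anchor=(0.5, 1.01),
              ncol=len(lut), frameon=False)
    return ax


rng = np.random.default_rng(0)
data = pd.DataFrame(rng.standard_normal((10, 3)), columns=list('ABC'))
data['group'] = rng.choice(['group_1', 'group_2', 'group_3'], size=len(data))
lut = {'group_1': 'red', 'group_2': 'blue', 'group_3': 'green'}

ax = heatmap_with_row_groups(data.drop(columns='group'), data['group'], lut)
plt.tight_layout()
plt.show()
```

Passing the groups and the lookup table separately, instead of a precomputed row_colors Series, lets the helper reuse lut for both the strip and the legend, so the two cannot drift out of sync.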