Seaborn heatmap - row and column statistics to display

Time:09-17

Is it possible to add row and column statistics on the edges of a Seaborn heatmap?

So on the right-hand side I want to display the mean of each row (one per month), and along the bottom edge (the years) I want to show the mean of each column.

Answer:

If you are working with a dataframe like this:

df = pd.DataFrame({'date': pd.date_range(start = '1949-01-01', end = '1960-12-01', freq = 'MS')})
df['value'] = np.random.randint(100, 600, len(df))
          date  value
0   1949-01-01    202
1   1949-02-01    535
2   1949-03-01    448
3   1949-04-01    370
4   1949-05-01    206
..         ...    ...
139 1960-08-01    238
140 1960-09-01    598
141 1960-10-01    180
142 1960-11-01    491
143 1960-12-01    262

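You have to reshape it with pivot, so that months become rows and years become columns. Then append the row means as a month_mean column and the column means as a year_mean row, and pass the result to seaborn.heatmap.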
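
The reshape step, with the means appended and a single colormap, can be sketched as follows. This is a minimal sketch assuming the sample data above; the fixed random seed, the name table, and the figure size are illustrative choices rather than part of the original answer. Because pivot() orders the month labels alphabetically, they are reindexed into calendar order.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Monthly sample data as in the question (the fixed seed is an assumption, for reproducibility)
rng = np.random.default_rng(0)
df = pd.DataFrame({'date': pd.date_range(start='1949-01-01', end='1960-12-01', freq='MS')})
df['value'] = rng.integers(100, 600, len(df))

# Long -> wide: one row per month, one column per year
df['month'] = df['date'].dt.month_name().str.slice(stop=3)
df['year'] = df['date'].dt.year
table = df.pivot(index='month', columns='year', values='value')

# pivot() sorts the month labels alphabetically; put them back in calendar order
month_order = pd.date_range('2000-01-01', periods=12, freq='MS').month_name().str.slice(stop=3)
table = table.reindex(month_order)

# Row statistic for the right edge, column statistic for the bottom edge
table['month_mean'] = table.mean(axis=1)
table.loc['year_mean'] = table.mean(axis=0)

# One colormap for everything; the answer's code below separates the margins
fig, ax = plt.subplots(figsize=(12, 7))
sns.heatmap(table, ax=ax, annot=True, fmt='.0f', cmap='Reds')
plt.show()
```

The bottom-right corner cell ends up holding the mean of the month_mean column, i.e. the grand mean of all values.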


Optionally, you can change the colormap of the last column and the last row, in order to improve visibility:

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np


df = pd.DataFrame({'date': pd.date_range(start = '1949-01-01', end = '1960-12-01', freq = 'MS')})
df['value'] = np.random.randint(100, 600, len(df))

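df['month'] = df['date'].dt.month_name().str.slice(stop = 3)
df['year'] = df['date'].dt.year
df = df.pivot(columns = 'year', index = 'month', values = 'value')
# pivot() sorts the month labels alphabetically, so restore calendar order
month_order = pd.date_range(start = '2000-01-01', periods = 12, freq = 'MS').month_name().str.slice(stop = 3)
df = df.reindex(month_order)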

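# Row statistic (right edge) and column statistic (bottom edge)
df['month_mean'] = df.mean(axis = 1)
df.loc['year_mean'] = df.mean(axis = 0)

# Copy holding only the monthly values: blank out the mean column and mean row
df_values = df.copy()
df_values['month_mean'] = float('nan')
df_values.loc['year_mean'] = float('nan')

# Copy holding only the means: blank out the body (positional slicing needs iloc, not loc)
df_means = df.copy()
df_means.iloc[:-1, :-1] = float('nan')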


fig, ax = plt.subplots(figsize = (12, 7))

# Both layers share one color scale; NaN cells are left blank by seaborn
vmin, vmax = df.to_numpy().min(), df.to_numpy().max()
sns.heatmap(ax = ax, data = df_values, annot = True, fmt = '.0f', cmap = 'Reds', vmin = vmin, vmax = vmax)
sns.heatmap(ax = ax, data = df_means, annot = True, fmt = '.0f', cmap = 'Blues', vmin = vmin, vmax = vmax)

plt.show()
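
As a variant that is not part of the original answer, the two NaN-filled copies can be replaced by boolean masks passed to the mask parameter of seaborn.heatmap, which leaves a cell blank wherever the mask is True. The sketch below rebuilds the same 13 × 13 table so it runs on its own; it pivots on the month number (which already sorts in calendar order) and relabels afterwards. The seed and the names body_mask and margin_mask are illustrative. Passing cbar=False to the second call keeps a single colorbar.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Rebuild the table: 12 months + year_mean rows, 12 years + month_mean columns (seed is illustrative)
rng = np.random.default_rng(0)
dates = pd.date_range(start='1949-01-01', end='1960-12-01', freq='MS')
df = pd.DataFrame({'month': dates.month, 'year': dates.year,
                   'value': rng.integers(100, 600, len(dates))})

# Pivoting on the month number keeps calendar order; relabel with abbreviations afterwards
df = df.pivot(index='month', columns='year', values='value')
df.index = pd.date_range('2000-01-01', periods=12, freq='MS').month_name().str.slice(stop=3)
df['month_mean'] = df.mean(axis=1)
df.loc['year_mean'] = df.mean(axis=0)

# True marks cells to hide: body_mask hides the margins, margin_mask hides the body
body_mask = np.zeros(df.shape, dtype=bool)
body_mask[-1, :] = True
body_mask[:, -1] = True
margin_mask = ~body_mask

vmin, vmax = df.to_numpy().min(), df.to_numpy().max()
fig, ax = plt.subplots(figsize=(12, 7))
sns.heatmap(df, mask=body_mask, ax=ax, annot=True, fmt='.0f', cmap='Reds', vmin=vmin, vmax=vmax)
sns.heatmap(df, mask=margin_mask, ax=ax, annot=True, fmt='.0f', cmap='Blues',
            vmin=vmin, vmax=vmax, cbar=False)
plt.show()
```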
