Custom annotation of text in seaborn heatmap

Time:01-30

I want to assign different font sizes to the positive and negative values in the following heatmap, plotted with seaborn.

    import numpy as np
    import pandas as pd
    import seaborn as sns  # %matplotlib inline
    import matplotlib.pyplot as plt

    data = {0: [0.000000, 0.000000], 1: [2.31049, 0.000000], 2: [-0.231049, 0.000000]}
    df = pd.DataFrame.from_dict(data, orient='index')
    sns.heatmap(
        df, cmap='bwr', vmax=10, vmin=0, annot=True, fmt='f',
        linewidths=0.25, annot_kws={"fontsize": 16}, center=0, square=True
    )
    sns.heatmap(
        df, cmap='bwr', vmax=0, vmin=-10, annot=True, fmt='f',
        linewidths=0.25, annot_kws={"fontsize": 6}, center=0, square=True
    )
    plt.show()

I tried specifying the min and max and plotting in two steps, but the colors and fonts aren't displayed correctly.

Any suggestions on how to fix this would be a great help.
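A small diagnostic sketch of why the two-step attempt misbehaves: each sns.heatmap call draws and annotates every cell of df, so the second call is painted on top of the first, and each call adds its own colorbar. It reuses the question's data, assumes the non-interactive Agg backend so it runs headless, and simply counts the artists left on the axes.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

data = {0: [0.0, 0.0], 1: [2.31049, 0.0], 2: [-0.231049, 0.0]}
df = pd.DataFrame.from_dict(data, orient="index")

fig, ax = plt.subplots()
# Same two calls as in the question: both draw the FULL dataframe.
sns.heatmap(df, cmap="bwr", vmin=0, vmax=10, center=0, annot=True, fmt="f",
            annot_kws={"fontsize": 16}, ax=ax)
sns.heatmap(df, cmap="bwr", vmin=-10, vmax=0, center=0, annot=True, fmt="f",
            annot_kws={"fontsize": 6}, ax=ax)

n_cells = df.size                   # 6 cells in the dataframe
n_texts = len(ax.texts)             # every cell is annotated once per call
n_meshes = len(ax.collections)      # one color mesh per call, stacked
n_colorbars = len(fig.axes) - 1     # each call adds its own colorbar axes
print(n_cells, n_texts, n_meshes, n_colorbars)
```

Because the second mesh and its small-font labels cover every cell, the first call's settings are effectively hidden, which matches the wrong colors and fonts described above.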


Answer:

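To make it easier to keep the properties of both passes in sync, the code below uses a for loop. For the positive pass, the dataframe is filtered to contain only the positive values, and for the negative pass only the negative values. (Indexing with a boolean mask such as df[df > 0] keeps the shape of the dataframe and fills in NaN for the values that are filtered out; seaborn leaves those cells blank.) Cells equal to 0 match neither filter, so they stay blank in both passes.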
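A minimal sketch of the filtering step on a tiny hand-made dataframe (the values are made up for illustration): the mask keeps the 2x2 shape, turns non-positive cells into NaN, and seaborn only annotates the cells that still hold a number. The Agg backend is an assumption so it runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

df = pd.DataFrame([[3, -2], [0, 5]])

# Boolean-mask indexing: same shape, cells failing the test become NaN.
pos = df[df > 0]
print(pos)

fig, ax = plt.subplots()
sns.heatmap(pos, annot=True, fmt=".0f", cbar=False, ax=ax)

# Only the non-NaN cells (3 and 5) receive a text annotation.
labels = sorted(t.get_text() for t in ax.texts)
print(labels)  # ['3', '5']
```

Both the -2 and the 0 cell end up blank here, which is why a second pass with df[df < 0] is needed for the negatives.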

vmin and vmax (and center) are set to the same values for both the negative and the positive pass of the loop, so both passes use one shared color scale and the colorbar covers the full range of values. To avoid drawing the colorbar twice, cbar=False is used for one of the two passes.
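A quick check of that claim, as a sketch with a small fixed dataframe (made-up values) and the Agg backend assumed: after both passes, each color mesh reports the same color limits via get_clim(), and the figure holds exactly one extra axes for the single colorbar.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

df = pd.DataFrame([[4, -7], [0, 9]])

fig, ax = plt.subplots()
# Positive pass without colorbar, negative pass with colorbar.
for part, show_cbar in [(df[df > 0], False), (df[df < 0], True)]:
    sns.heatmap(part, cmap="bwr", vmin=-10, vmax=10, center=0,
                cbar=show_cbar, ax=ax)

# Color limits of each pass's mesh; identical limits mean one shared scale.
clims = [mesh.get_clim() for mesh in ax.collections]
n_colorbars = len(fig.axes) - 1
print(clims, n_colorbars)
```

If the passes used different vmin/vmax values, the same color would stand for different numbers in different cells, and the single colorbar would only be valid for one of them.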

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

# Demo data: random integers from -10 to 10
df = pd.DataFrame(np.random.randint(-10, 11, (12, 12)))
fig, ax = plt.subplots()
for posneg in ['pos', 'neg']:
    # Draw one sign per pass; filtered-out cells become NaN and stay blank.
    # The shared vmin/vmax/center keep both passes on one color scale,
    # and only the 'neg' pass draws the colorbar.
    sns.heatmap(
        df[df > 0] if posneg == 'pos' else df[df < 0],
        cmap='bwr', vmin=-10, vmax=10, center=0, annot=True, fmt='.0f',
        annot_kws={"fontsize": 16 if posneg == 'pos' else 8},
        cbar=(posneg != 'pos'), cbar_kws={'ticks': range(-10, 11, 2)},
        linewidths=0.25, square=True, ax=ax
    )
plt.show()
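An alternative sketch, not part of the original answer: draw a single heatmap, then walk over ax.texts (each annotation is a matplotlib Text object) and set each font size from the parsed label. This also keeps zero-valued cells, which the two filtered passes leave blank. It assumes the labels are plain numbers that float() can parse (true for fmt='.3f'); treating 0 as "positive" and the sizes 16/8 are arbitrary choices, and the Agg backend is assumed so it runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

data = {0: [0.0, 0.0], 1: [2.31049, 0.0], 2: [-0.231049, 0.0]}
df = pd.DataFrame.from_dict(data, orient="index")

fig, ax = plt.subplots()
sns.heatmap(df, cmap="bwr", vmin=-10, vmax=10, center=0, annot=True,
            fmt=".3f", linewidths=0.25, square=True, ax=ax)

# Resize each annotation after drawing, based on the value it shows.
for text in ax.texts:
    value = float(text.get_text())
    text.set_fontsize(16 if value >= 0 else 8)

pairs = [(float(t.get_text()), t.get_fontsize()) for t in ax.texts]
print(pairs)
```

Since everything is drawn in one call, the colors, the colorbar and the zero cells come out exactly as a normal heatmap, and only the text sizes differ.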

[Figure: sns.heatmap with different font sizes for positive and negative values]
