How to avoid multiple colorbars when looping seaborn heatmap

Time:03-25

I made 3 plots in a for loop with the following code, but the colorbars accumulate: the 1st figure has 1 colorbar, the 2nd has 2, and the 3rd has 3. It seems each previous plot leaves its colorbar on the current plot. How can I avoid this?

for f in files:
    #print(f)
    roi_signals = pd.read_csv(f, sep='\t')
   
    fig = sns.heatmap(roi_signals)
    fig_name = f.replace('.txt', '.png')
    plt.savefig(fig_name)

CodePudding user response:

By default, sns.heatmap plots onto the existing Axes and allocates space for a colorbar:

This will draw the heatmap into the currently-active Axes if none is provided to the ax argument. Part of this Axes space will be taken and used to plot a colormap, unless cbar is False or a separate Axes is provided to cbar_ax.
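
The accumulation can be observed by counting the Axes on the figure after each call. This is only a minimal sketch: the random DataFrames are hypothetical stand-ins for the tab-separated files in the question, and the Agg backend is chosen just so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical stand-in data; the question reads tab-separated files instead.
rng = np.random.default_rng(0)
frames = [pd.DataFrame(rng.random((4, 5))) for _ in range(3)]

fig = plt.figure()
axes_per_call = []
for data in frames:
    sns.heatmap(data)  # no ax given, so it draws into the current Axes
    # Count the heatmap Axes plus every colorbar Axes created so far.
    axes_per_call.append(len(fig.axes))

print(axes_per_call)  # the count grows by one colorbar Axes per call
plt.close(fig)
```

Each call reuses the same heatmap Axes but asks the figure for a fresh colorbar Axes, which is why the saved images show one more colorbar each time.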


The simplest solution is to clear the figure each iteration with plt.clf:

for f in files:
    roi_signals = pd.read_csv(f, sep='\t')

    sns.heatmap(roi_signals)
    plt.savefig(f.replace('.txt', '.png'))
    plt.clf() # clear figure before next iteration

Or specify cbar_ax to overwrite the previous iteration's colorbar:

fig = plt.figure()
for i, f in enumerate(files):
    roi_signals = pd.read_csv(f, sep='\t')

    cbar_ax = fig.axes[-1] if i else None # retrieve previous cbar_ax (if exists)
    sns.heatmap(roi_signals, cbar_ax=cbar_ax)
    plt.savefig(f.replace('.txt', '.png'))
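
A variation on the cbar_ax idea, assuming all heatmaps can share one figure: create the heatmap Axes and the colorbar Axes once, then clear and redraw both for each dataset. The random DataFrames, the temporary output directory, the file names, and the 20:1 width ratio are placeholders and not part of the original answer.

```python
import tempfile
from pathlib import Path

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical stand-in data and output location.
rng = np.random.default_rng(0)
frames = [pd.DataFrame(rng.random((4, 5))) for _ in range(3)]
out_dir = Path(tempfile.mkdtemp())

# One wide Axes for the heatmap and one narrow Axes reserved for the colorbar.
fig, (ax, cbar_ax) = plt.subplots(1, 2, gridspec_kw={"width_ratios": [20, 1]})

axes_counts = []
saved = []
for i, data in enumerate(frames):
    ax.clear()       # drop the previous heatmap
    cbar_ax.clear()  # drop the previous colorbar
    sns.heatmap(data, ax=ax, cbar_ax=cbar_ax)
    path = out_dir / f"heatmap_{i}.png"
    fig.savefig(path)
    saved.append(path)
    axes_counts.append(len(fig.axes))

plt.close(fig)
```

Because both target Axes are passed explicitly, seaborn never has to create a new colorbar Axes, so the figure keeps exactly two Axes throughout the loop.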

Or just create a new figure per iteration. This is not recommended for many iterations, because pyplot keeps every figure in memory until it is closed:

for f in files:
    roi_signals = pd.read_csv(f, sep='\t')

    fig = plt.figure() # create new figure each iteration
    sns.heatmap(roi_signals)
    plt.savefig(f.replace('.txt', '.png'))
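
If a new figure per iteration is still preferred, closing each figure after saving should keep the number of open figures from growing. A minimal sketch, with random DataFrames and a temporary directory standing in for the real files (both are assumptions made for illustration):

```python
import tempfile
from pathlib import Path

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical stand-in data and output location.
rng = np.random.default_rng(0)
frames = [pd.DataFrame(rng.random((4, 5))) for _ in range(3)]
out_dir = Path(tempfile.mkdtemp())

open_before = len(plt.get_fignums())
saved = []
for i, data in enumerate(frames):
    fig, ax = plt.subplots()   # fresh figure and Axes for this heatmap
    sns.heatmap(data, ax=ax)   # draw into the Axes just created
    path = out_dir / f"heatmap_{i}.png"
    fig.savefig(path)
    plt.close(fig)             # release the figure once it is on disk
    saved.append(path)
open_after = len(plt.get_fignums())
```

Comparing `open_before` and `open_after` confirms that the loop leaves no extra figures registered with pyplot.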

CodePudding user response:

I just fixed it with this:

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

for f in files:
    #print(f)
    roi_signals = pd.read_csv(f, sep='\t')

    fig, ax = plt.subplots()
    sns.heatmap(roi_signals, ax=ax)
    fig_name = f.replace('.txt', '.png')
    plt.savefig(fig_name)