Legends are printing twice when calling matplotlib subplots

Time:12-28

I'm writing code in matplotlib to plot multiple histograms in a subplot grid. However, when I call fig.legend() at the end, the legend entry for each plot appears twice. Any guidance on how to resolve this issue would be greatly appreciated :) Here is my code:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('darkgrid')
def get_cmap(n, name='hsv'):
    return plt.get_cmap(name, n)  # plt.cm.get_cmap was removed in newer Matplotlib releases
def isSqrt(n):
    sq_root = int(np.sqrt(n))
    return (sq_root*sq_root) == n
df = pd.read_csv('mpg.csv')
df2 = pd.read_csv('dm_office_sales.csv')
df['miles'] = df2['salary']
numericClassifier = ['int16', 'int32', 'int64', 'float16', 'float32', 'float64']
newdf = df.select_dtypes(numericClassifier)
columns = newdf.columns.tolist()
n = len(columns)
cmap = get_cmap(n)
if(isSqrt(n)):
    nrows = ncols = int(np.sqrt(n))
else:
    ncols = int(np.sqrt(n))
    for i in range(ncols,50):
        if ncols*i >= n:
            nrows = i
            break
        else:
            pass
fig,ax = plt.subplots(nrows,ncols)
count = 0
print(nrows,ncols)
for i in range(0,nrows,1):
    for j in range(0,ncols,1):
        print('ncols = {}'.format(j),'nrows = {}'.format(i),'count = {}'.format(count))
        if count<=n-1:
            plt_new = sns.histplot(df[columns[count]],ax=ax[i,j],facecolor=cmap(count),kde=True,edgecolor='black',label=df[columns[count]].name)
            patches = plt_new.get_children()
            for patch in patches:
                patch.set_alpha(0.8)
            color = patches[0].get_facecolor()
            ax[i,j].set_xlabel('{}'.format(df[columns[count]].name))
            ax[i,j].xaxis.label.set_fontsize(10)
            ax[i,j].xaxis.label.set_fontname('Arial')
            ax[i,j].set(xlabel=None)
            ax[i,j].tick_params(axis='y', labelsize=8)
            count += 1
        else:
            break
    
for i in range(0,nrows,1):
    for j in range(0,ncols,1):
        if not ax[i,j].has_data():
            fig.delaxes(ax[i,j])
        else:
            pass

plt.suptitle('Histograms').set_fontname('Arial')
plt.tight_layout()
fig.legend(loc='upper right')
plt.show()

Here is the output: [screenshot not included; every legend entry appears twice]

Answer:

sns.histplot seems to create two bar containers: first a dummy one, then the real one. (Tested with seaborn 0.12.1; this may behave differently in other versions.) The label therefore gets assigned to both the dummy and the real bar container, so fig.legend() lists it twice. A workaround is to remove the label from the dummy bar container.
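The mechanism can be reproduced with plain matplotlib. This is an illustration under assumptions, not seaborn's actual code: two `ax.bar` calls with the same made-up label `"mpg"` stand in for the dummy and the real bar container that `sns.histplot` appears to create. Matplotlib's automatic legend collects every labelled container on an Axes, so the label shows up twice until the first container's label is blanked, which is what the answer's `containers[0].set_label('')` line does.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
# Assumption: two same-labelled containers mimic the dummy + real
# bar containers that sns.histplot seems to leave on the Axes
ax.bar([0, 1], [3, 5], label="mpg")  # stands in for the dummy container
ax.bar([0, 1], [3, 5], label="mpg")  # stands in for the real histogram bars

# Each labelled container yields its own legend entry
labels_before = ax.get_legend_handles_labels()[1]
print(labels_before)  # ['mpg', 'mpg']

# The workaround: blank the label of the first container, then build the legend
ax.containers[0].set_label("")
legend = fig.legend(loc="upper right")
labels_after = [text.get_text() for text in legend.get_texts()]
print(labels_after)  # ['mpg']
```

Empty labels (and labels starting with an underscore) are skipped by the automatic legend, which is why setting the label to `''` is enough.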

Here is the adapted code. Seaborn's mpg dataset is used so the example is easily reproducible. As the first and last colors of the hsv colormap are both red, get_cmap(n + 1) ensures that n different colors are chosen. Some superfluous code has been removed.
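The reason for `n + 1` can be checked directly. A minimal sketch, assuming an arbitrary `n = 5` and using `plt.get_cmap(name, lut)` to sample the colormap into a fixed number of entries: with `n` entries, index 0 and index `n - 1` land on the two red ends of `hsv`; with `n + 1` entries and only the first `n` used, the last color actually used is far from red.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.colors import to_hex

n = 5  # arbitrary number of columns, for illustration only

# n entries: indices 0 and n - 1 sit at both ends of hsv (red and almost red)
cmap_n = plt.get_cmap("hsv", n)
first, last = cmap_n(0), cmap_n(n - 1)
print(to_hex(first), to_hex(last))

# n + 1 entries, but only indices 0 .. n - 1 are used: the red end is skipped
cmap_n1 = plt.get_cmap("hsv", n + 1)
last_used = cmap_n1(n - 1)
print(to_hex(last_used))
```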

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

def get_cmap(n, name='hsv'):
    return plt.get_cmap(name, n)  # plt.cm.get_cmap was removed in newer Matplotlib releases

sns.set_style('darkgrid')
df = sns.load_dataset('mpg')
numericClassifier = ['int16', 'int32', 'int64', 'float16', 'float32', 'float64']
newdf = df.select_dtypes(numericClassifier)
columns = newdf.columns.tolist()
n = len(columns)
cmap = get_cmap(n + 1)
ncols = int(np.sqrt(n))
nrows = int(np.ceil(n / ncols))
fig, ax = plt.subplots(nrows, ncols, squeeze=False)  # keep ax 2-D so ax[i, j] also works when ncols == 1
count = 0
print(nrows, ncols)
for i in range(0, nrows):
    for j in range(0, ncols):
        if count < n:
            # print(f'ncols = {j}; nrows = {i}; count = {count}')
            sns.histplot(df[columns[count]], ax=ax[i, j], facecolor=cmap(count), kde=True, edgecolor='black',
                         label=df[columns[count]].name)
            ax[i, j].containers[0].set_label('')  # seaborn seems to create a dummy bar container, remove its label
            for patch in ax[i, j].get_children():
                patch.set_alpha(0.8)
            ax[i, j].tick_params(axis='y', labelsize=8)
            count += 1
for i in range(0, nrows):
    for j in range(0, ncols):
        if not ax[i, j].has_data():
            fig.delaxes(ax[i, j])

plt.suptitle('Histograms').set_fontname('Arial')
fig.legend(loc='upper right')
plt.tight_layout()
plt.subplots_adjust(right=0.75) # make extra space for the legend
plt.show()

[Image: subplots of sns.histplot with figure legend]
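Among the code the answer dropped is the question's isSqrt check and row-search loop, replaced by `ncols = int(np.sqrt(n))` and `nrows = int(np.ceil(n / ncols))`. A quick sanity check that both give the same grid for 1 to 100 columns; it uses `math` instead of NumPy, and the function names `question_grid` and `answer_grid` are made up for this comparison.

```python
import math


def question_grid(n):
    """Grid sizing as written in the question: isSqrt check, then a search loop."""
    sq_root = int(math.sqrt(n))
    if sq_root * sq_root == n:
        return sq_root, sq_root
    ncols = sq_root
    for i in range(ncols, 50):
        if ncols * i >= n:
            return i, ncols
    raise ValueError("no grid found with fewer than 50 rows")


def answer_grid(n):
    """Grid sizing as written in the answer."""
    ncols = int(math.sqrt(n))
    return math.ceil(n / ncols), ncols


mismatches = [n for n in range(1, 101) if question_grid(n) != answer_grid(n)]
print(mismatches)  # []
```

Both agree because `ncols = floor(sqrt(n))` guarantees `n / ncols >= ncols`, so the smallest row count the loop can find is exactly `ceil(n / ncols)`.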

Upon further investigation, it seems the dummy bar container isn't created when sns.histplot is called with color= instead of facecolor=.
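If neither workaround fits, for example because another seaborn version adds different helper artists, a version-independent alternative (not part of the answer) is to deduplicate legend entries by label before building the figure legend. A minimal sketch: `unique_legend_entries` is a hypothetical helper name, and two same-labelled `ax.bar` calls per Axes stand in for the duplicated histogram containers.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def unique_legend_entries(fig):
    """Hypothetical helper: gather legend handles/labels from every Axes in
    `fig`, keeping only the first handle seen for each label."""
    by_label = {}
    for ax in fig.axes:
        handles, labels = ax.get_legend_handles_labels()
        for handle, label in zip(handles, labels):
            by_label.setdefault(label, handle)  # later duplicates are skipped
    return list(by_label.values()), list(by_label.keys())


fig, axs = plt.subplots(1, 2)
for ax, name in zip(axs, ["mpg", "weight"]):
    # Two containers with the same label mimic the duplicated histogram bars
    ax.bar([0, 1], [1, 2], label=name)
    ax.bar([0, 1], [2, 1], label=name)

handles, labels = unique_legend_entries(fig)
legend = fig.legend(handles, labels, loc="upper right")
legend_labels = [text.get_text() for text in legend.get_texts()]
print(legend_labels)  # ['mpg', 'weight']
```

Because it only compares labels, this also merges entries that legitimately share a label across subplots, which is usually what a figure-level legend wants anyway.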

The code could also be written in a more "pythonic" way. Among other things, this means avoiding repeated code and explicit indices, and zip is an important helper for that. Besides avoiding repetition, the code becomes shorter and easier to modify, and once you get used to this style, it is also easier to read and reason about.

The main part could look like this, for example:

fig, axs = plt.subplots(nrows=nrows, ncols=ncols)
for column, ax, color in zip(columns, axs.flat, cmap(range(n))):
    # using `color=` instead of `facecolor=` seems to avoid creating the dummy bars
    sns.histplot(df[column], ax=ax, color=color, kde=True, edgecolor='black', label=column)
    for patch in ax.get_children():
        patch.set_alpha(0.8)
    ax.tick_params(axis='y', labelsize=8)
for ax in axs.flat:
    if not ax.has_data():
        fig.delaxes(ax)
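For readers without seaborn or the CSV files, the same zip-based layout can be tried end to end with plain matplotlib. This is a sketch under assumptions: the random data and column names are made up, and `ax.hist` replaces `sns.histplot`, so there is no KDE curve and no dummy container to work around. `squeeze=False` is added so `axs.flat` works for any grid shape.

```python
import math

import matplotlib
matplotlib.use("Agg")  # headless backend; remove to open a window with plt.show()
import matplotlib.pyplot as plt
import numpy as np

# Made-up stand-in data: one array per "column" (the answer uses the mpg DataFrame)
rng = np.random.default_rng(0)
names = ["mpg", "cylinders", "displacement", "horsepower", "weight"]
data = {name: rng.normal(size=200) for name in names}

n = len(data)
ncols = int(math.sqrt(n))          # same grid sizing as the answer
nrows = math.ceil(n / ncols)
cmap = plt.get_cmap("hsv", n + 1)  # one spare entry so first and last colors differ

# squeeze=False keeps axs 2-D, so axs.flat works for any grid shape
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False)
for (name, values), ax, color in zip(data.items(), axs.flat, cmap(range(n))):
    # ax.hist stands in for sns.histplot (no KDE curve here)
    ax.hist(values, color=tuple(color), edgecolor="black", alpha=0.8, label=name)
    ax.tick_params(axis="y", labelsize=8)
for ax in axs.flat:
    if not ax.has_data():  # axes left over after zip ran out of columns
        fig.delaxes(ax)

legend = fig.legend(loc="upper right")
legend_labels = [text.get_text() for text in legend.get_texts()]
print(len(fig.axes), legend_labels)
```

With five columns the grid is 3 x 2, so one empty Axes is removed and the figure legend lists each column once.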