I'm writing code in matplotlib to plot multiple histograms in a subplot grid. However, when I call `fig.legend()` at the end, the legend entry for each plot appears twice. Any guidance on how to resolve this issue would be greatly appreciated :) Here is my code:
import math

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# Use seaborn's dark grid theme for every figure drawn below.
sns.set_style(style='darkgrid')
def get_cmap(n, name='hsv'):
    """Return the colormap *name* resampled to *n* discrete colors.

    Calling the result with an integer index (e.g. ``cmap(i)``) yields the
    RGBA color for that index.

    ``plt.cm.get_cmap`` (i.e. ``matplotlib.cm.get_cmap``) was deprecated in
    Matplotlib 3.7 and removed in 3.9; ``pyplot.get_cmap(name, lut)`` takes
    the same arguments and is still available.
    """
    return plt.get_cmap(name, n)
def isSqrt(n):
    """Return True if *n* is a perfect square (0, 1, 4, 9, ...).

    Uses exact integer arithmetic (``math.isqrt``) rather than
    ``int(np.sqrt(n))``: the float square root loses precision for large
    integers, and ``np.sqrt`` cannot handle Python ints beyond the int64
    range. Negative numbers are never perfect squares, so they return False.

    n: an integer (anything ``math.isqrt`` accepts once known non-negative).
    """
    if n < 0:
        return False
    sq_root = math.isqrt(n)
    return sq_root * sq_root == n
# Load the data and pick the numeric columns; each one gets its own histogram.
df = pd.read_csv('mpg.csv')
df2 = pd.read_csv('dm_office_sales.csv')
df['miles'] = df2['salary']
numericClassifier = ['int16', 'int32', 'int64', 'float16', 'float32', 'float64']
newdf = df.select_dtypes(numericClassifier)
columns = newdf.columns.tolist()
n = len(columns)
# The 'hsv' colormap starts and ends on red, so sample n + 1 colors and use
# only the first n; otherwise the first and last histograms share a color.
cmap = get_cmap(n + 1)
# Lay the histograms out on a near-square grid: ncols = floor(sqrt(n)) and
# just enough rows to hold all n plots, ceil(n / ncols). This also covers the
# perfect-square case (nrows == ncols), and unlike a search capped at 50 rows
# it always defines nrows.
if n == 0:
    raise ValueError('No numeric columns to plot.')
ncols = int(np.sqrt(n))
nrows = int(np.ceil(n / ncols))
# squeeze=False keeps `ax` a 2-D array even for a 1xN, Nx1 or 1x1 grid, so
# indexing and `ax.flat` work for any number of columns.
fig, ax = plt.subplots(nrows, ncols, squeeze=False)
print(nrows, ncols)
# One histogram (with KDE) per numeric column, filled row by row. zip stops at
# the shorter input, so surplus grid cells are left empty.
for count, (column, axis) in enumerate(zip(columns, ax.flat)):
    sns.histplot(df[column], ax=axis, facecolor=cmap(count), kde=True,
                 edgecolor='black', label=column)
    # With facecolor=, seaborn (observed with 0.12.1) seems to add a dummy bar
    # container before the real one, and both carry the label -> duplicated
    # legend entries. Blank every container label except the last; the slice
    # is a no-op when only one container exists.
    for container in axis.containers[:-1]:
        container.set_label('')
    for patch in axis.get_children():
        patch.set_alpha(0.8)
    # The column name appears in the figure legend, so drop the x-axis label.
    axis.set(xlabel=None)
    axis.tick_params(axis='y', labelsize=8)
# Remove grid cells that did not receive a histogram. Iterating a snapshot of
# fig.axes works whatever shape `ax` has (2-D array, 1-D array or a single
# Axes), and is safe while axes are being deleted.
for axis in list(fig.axes):
    if not axis.has_data():
        fig.delaxes(axis)
plt.suptitle('Histograms').set_fontname('Arial')
plt.tight_layout()
# Build the figure legend explicitly, keeping one handle per label.
# sns.histplot can register the same label on two bar containers (a dummy one
# and the real one), which is why a bare fig.legend() listed every entry twice.
by_label = {}
for axis in fig.axes:
    for handle, label in zip(*axis.get_legend_handles_labels()):
        by_label[label] = handle  # keep the last handle seen for each label
fig.legend(list(by_label.values()), list(by_label.keys()), loc='upper right')
plt.show()
Answer:
`sns.histplot` seems to create two bar containers: first a dummy one, and then the real one. (Tested with seaborn 0.12.1; this might work differently in other versions.) Therefore, the label gets assigned to both the dummy and the real bar container. A workaround is to remove the label of the dummy bar container.
Here is the adapted code. Seaborn's `mpg` dataset is used to have an easily reproducible example. As the first and last colors of the `hsv` colormap are both red, `get_cmap(n + 1)` ensures that `n` different colors are chosen. Some superfluous code has been removed.
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
def get_cmap(n, name='hsv'):
    """Return the colormap *name* resampled to *n* discrete colors.

    Calling the result with an integer index (e.g. ``cmap(i)``) yields the
    RGBA color for that index.

    ``plt.cm.get_cmap`` (i.e. ``matplotlib.cm.get_cmap``) was deprecated in
    Matplotlib 3.7 and removed in 3.9; ``pyplot.get_cmap(name, lut)`` takes
    the same arguments and is still available.
    """
    return plt.get_cmap(name, n)
sns.set_style('darkgrid')
df = sns.load_dataset('mpg')
numericClassifier = ['int16', 'int32', 'int64', 'float16', 'float32', 'float64']
newdf = df.select_dtypes(numericClassifier)
columns = newdf.columns.tolist()
n = len(columns)
# The 'hsv' colormap starts and ends on red: sample n + 1 colors and use only
# the first n so every histogram gets a distinct color.
cmap = get_cmap(n + 1)
# Near-square grid: floor(sqrt(n)) columns and just enough rows for n plots.
ncols = int(np.sqrt(n))
nrows = int(np.ceil(n / ncols))
# squeeze=False keeps `ax` 2-D even for a single row/column or a single plot.
fig, ax = plt.subplots(nrows, ncols, squeeze=False)
print(nrows, ncols)
# One histogram per numeric column, filled row by row; zip stops at the
# shorter input, so surplus grid cells stay empty.
for count, (column, axis) in enumerate(zip(columns, ax.flat)):
    sns.histplot(df[column], ax=axis, facecolor=cmap(count), kde=True, edgecolor='black',
                 label=column)
    # seaborn seems to create a dummy bar container before the real one; blank
    # every container label except the last. Unlike blanking containers[0],
    # this never removes the real label when only one container exists.
    for container in axis.containers[:-1]:
        container.set_label('')
    for patch in axis.get_children():
        patch.set_alpha(0.8)
    axis.tick_params(axis='y', labelsize=8)
# Remove grid cells that did not receive a histogram. Iterating a snapshot of
# fig.axes works whatever shape `ax` has, and is safe while deleting axes.
for axis in list(fig.axes):
    if not axis.has_data():
        fig.delaxes(axis)
# 'Arial' (not the misspelled 'ariel', which matplotlib cannot find and
# replaces with its fallback font).
plt.suptitle('Histograms').set_fontname('Arial')
fig.legend(loc='upper right')
plt.tight_layout()
plt.subplots_adjust(right=0.75)  # make extra space for the legend
plt.show()
Upon further investigation, it seems the dummy bar container isn't created when `sns.histplot` is called with `color=` instead of `facecolor=`.
The code could also be written a bit more "pythonically". Among other things, this means trying to avoid repeated code and explicit indices. To achieve this, `zip` is an important helper. Besides avoiding repetition, it makes the code shorter and easier to modify. Once you get used to it, it also becomes easier to read and to reason about.
The main part could look like this, for example:
# squeeze=False keeps `axs` a 2-D array even for a single row/column or a
# single subplot, so `axs.flat` is always available.
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False)
for column, ax, color in zip(columns, axs.flat, cmap(range(n))):
    # using `color=` instead of `facecolor=` seems to avoid the creation of dummy bars
    sns.histplot(df[column], ax=ax, color=color, kde=True, edgecolor='black', label=column)
    for patch in ax.get_children():
        patch.set_alpha(0.8)
    ax.tick_params(axis='y', labelsize=8)
# Remove the grid cells that did not receive a histogram.
for ax in axs.flat:
    if not ax.has_data():
        fig.delaxes(ax)