how to add a common legend for subplots of bar charts with ax.legend()

Time:09-24

Unfortunately, I have already tried several approaches I found, without success. How can I create a common legend at the top of the figure for my three subplots? The following code didn't work and the execution got stuck; the whole code snippet is below. I thought I could handle it by myself, but I definitely need your help.

fig.tight_layout()
handles, labels = axs[0].get_legend_handles_labels()
fig.legend(handles, labels, fontsize=fs, loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=4)

Here is my whole code snippet:


import matplotlib.pyplot as plt
import pandas as pd

df = pd.read_csv('energy_production_ver4.csv', sep=",")

# Figure Properties
fs = 4  # 30
lw = 2  # 3
width_bars=0.5
ec = 'dimgray'

# Create Subplots
fig, axs = plt.subplots(3, sharex=True, sharey=True, num=None, figsize=(25, 16), dpi=300, facecolor='w',
                        edgecolor='k')  # 26 15
plt.rc('text', usetex=True)
plt.rc('font', family='serif')
plt.rc('font', size=fs-1)
plt.gcf().subplots_adjust(bottom=0.15)

#adjust Time Stamp to desired format

temp_x=df["DateTimeStamp"]
y1 = df['Production']
y2 = df['Consumption']
y3 = df['Production CHP 22kW']
y4 = df['Battery charge']
y5 = df['Battery discharge']
y6 = df['Grid supply']
y7 = df['Feed in grid']

# axs[0].grid(True, linestyle=':')
axs[0].yaxis.grid(linestyle=':')
axs[0].tick_params(axis='both', labelsize=fs)
axs[0].bar(temp_x, y1, label=r'$PV^\mathrm{s}$', color='blue', linewidth=lw / 2, width=width_bars)
axs[0].bar(temp_x, y2, label=r'$EV^\mathrm{s}$', color='orange', linewidth=lw / 2, width=width_bars)
axs[0].bar(temp_x, y3, label=r'CHP', color='green', linewidth=lw / 2, width=width_bars)
axs[0].set_ylabel(r'Energy in MWh', fontsize=fs)


axs[1].yaxis.grid(linestyle=':')
axs[1].tick_params(axis='both', labelsize=fs)
axs[1].bar(temp_x, y4, label=r'$BSS^\mathrm{CH}$', color='red', linewidth=lw / 2, width=width_bars)
axs[1].bar(temp_x, y5, label=r'$BSS^\mathrm{d}$', color='purple', linewidth=lw / 2, width=width_bars)
axs[1].set_ylabel(r'Energy in MWh', fontsize=fs)


axs[2].yaxis.grid(linestyle=':')
axs[2].tick_params(axis='both', labelsize=fs)
axs[2].bar(temp_x, y6, label=r'$GRID^\mathrm{CH}$', color='brown', linewidth=lw / 2, width=width_bars)
axs[2].bar(temp_x, y7, label=r'$GRID^\mathrm{d}$', color='pink', linewidth=lw / 2, width=width_bars)
axs[2].set_ylabel(r'Energy in MWh', fontsize=fs)



plt.xticks(rotation=90)
handles, labels = axs.get_legend_handles_labels() #does not work?
fig.legend(handles, labels, fontsize=fs, loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=4)

plt.show()

Answer:

The general idea works (I believe this post or this one is your source, and it does work for a single axis). The problem here is that you have several axes wrapped in axs, which is a NumPy array, so calling axs.get_legend_handles_labels() directly gives you something like

AttributeError: 'numpy.ndarray' object has no attribute 'get_legend_handles_labels'
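To see why the call fails, here is a minimal sketch that inspects what plt.subplots returns. The Agg backend and the extra 2 x 2 grid are my own additions (not in the question), used only so it runs headless and shows why flatten() is handy.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, assumed here so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

# One column of three subplots, as in the question
fig, axs = plt.subplots(3, sharex=True, sharey=True)
print(type(axs).__name__, axs.shape)                 # ndarray (3,)

# Each element is an Axes, which does have the method; the array does not
print(hasattr(axs[0], "get_legend_handles_labels"))  # True
print(hasattr(axs, "get_legend_handles_labels"))     # False

# A 2 x 2 grid gives a 2-D array; flatten() yields its Axes as a flat sequence
fig2, grid = plt.subplots(2, 2)
print(grid.shape, len(grid.flatten()))               # (2, 2) 4

try:
    axs.get_legend_handles_labels()
except AttributeError as err:
    print("AttributeError:", err)
```

So any Axes method has to be called on the individual elements, not on the array itself.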

The way around this is to loop over all the axes and concatenate their legend handles and labels, e.g. in the following way:

...
handles, labels = [], []
for ax in axs.flatten():
    h, l = ax.get_legend_handles_labels()
    handles += h  # append this subplot's handles and labels to the running lists
    labels += l

fig.legend(handles, labels, fontsize=fs, loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=4)
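Below is a self-contained sketch of the same approach that runs without the CSV file. The data are random placeholder values (hypothetical, not taken from energy_production_ver4.csv), the helper name collect_handles_labels is my own, and the Agg backend, plain-text labels and layout numbers are assumptions chosen so it runs anywhere; swap in your DataFrame columns and call plt.show() instead.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; remove this line to use plt.show()
import matplotlib.pyplot as plt
import numpy as np


def collect_handles_labels(axes):
    """Concatenate the legend handles and labels of every Axes in `axes`."""
    handles, labels = [], []
    for ax in np.ravel(axes):  # works for 1-D and 2-D arrays of Axes
        h, l = ax.get_legend_handles_labels()
        handles += h
        labels += l
    return handles, labels


# Placeholder data standing in for the CSV (hypothetical random values)
rng = np.random.default_rng(0)
x = np.arange(12)
panels = [  # (column name, bar color) per subplot, mirroring the question
    [("Production", "blue"), ("Consumption", "orange"),
     ("Production CHP 22kW", "green")],
    [("Battery charge", "red"), ("Battery discharge", "purple")],
    [("Grid supply", "brown"), ("Feed in grid", "pink")],
]

fig, axs = plt.subplots(3, sharex=True, sharey=True, figsize=(10, 6))
for ax, columns in zip(axs, panels):
    for name, color in columns:
        ax.bar(x, rng.uniform(0, 5, x.size), label=name, color=color, width=0.5)
    ax.set_ylabel("Energy in MWh")
    ax.yaxis.grid(linestyle=":")

handles, labels = collect_handles_labels(axs)

# Leave the top 12 % of the figure free and anchor the legend inside it,
# so the legend stays on the canvas instead of sitting above it
fig.subplots_adjust(top=0.88)
legend = fig.legend(handles, labels, loc="upper center",
                    bbox_to_anchor=(0.5, 1.0), ncol=4)

fig.canvas.draw()  # render once; use plt.show() or fig.savefig(...) in practice
print(labels)
```

With loc='upper center' and bbox_to_anchor=(0.5, 1.05), the top edge of the legend is placed above the top of the figure, so it can end up off the visible canvas; reserving space with subplots_adjust (or saving with fig.savefig(..., bbox_inches='tight')) avoids that. Also note that calling fig.legend() with no handles or labels collects the labelled artists from all Axes of the figure, so the explicit loop is mainly useful when you want to filter or reorder the entries.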