I am working with a relative abundance plot like this:
But I need help displaying the total of the values for each category on top of each bar. The result should look like this:
That is, the total count for each x-axis category ('1-2', '2-3', '3-4', '4-5').
My code:
import pandas as pd
import matplotlib.pyplot as plt
# Raw counts: one row per size class, one column per index.
data = {
    'Index A': [1, 5, 8, 9],
    'Index B': [19, 3, 7, 9],
    'Index C': [14, 16, 7, 8],
}
stacked_data = pd.DataFrame(data, index=['1-2', '2-3', '3-4', '4-5'])

# Rescale each row so its entries are percentages of that row's total
# (multiply by 100 first, then divide by the row sum).
stacked_data = stacked_data.mul(100).div(stacked_data.sum(axis=1), axis=0)

list_of_colors = ['blue', 'red', 'yellow']
ax = stacked_data.plot(
    kind='bar',
    stacked=True,
    width=1,
    figsize=(9, 8),
    edgecolor=None,
    color=list_of_colors,
)

plt.xlabel('Size', fontsize=15)
plt.ylabel('Relative abundance [%]', fontsize=15)
plt.xticks(fontsize=12, rotation=0)  # horizontal labels (same as 360 degrees)
plt.yticks(fontsize=12)
plt.legend(
    title='Index',
    loc=0,
    bbox_to_anchor=(1, .965),
    frameon=True,
    facecolor='white',
    fontsize=12,
)
plt.show()
Any help is appreciated!
Answer (from a CodePudding user):
Compute the per-category totals before converting the data frame to relative values, because after normalisation every row sums to 100 and the original counts are lost. Then annotate each bar by pairing those totals with the bars' x positions on the plot. The annotation coordinates are in data units, so y=100 places each label at the top of its (100 %) stacked bar.
import pandas as pd
import matplotlib.pyplot as plt
# Raw counts: one row per size class, one column per index.
data = {
    'Index A': [1, 5, 8, 9],
    'Index B': [19, 3, 7, 9],
    'Index C': [14, 16, 7, 8],
}
stacked_data = pd.DataFrame(data, index=['1-2', '2-3', '3-4', '4-5'])

# Absolute total per size class. This must be taken BEFORE normalising,
# because afterwards every row sums to 100 and the counts are lost.
total_data = stacked_data.sum(axis=1)

# Convert each row to percentages of its own total (vectorised form of
# apply(lambda x: x*100/sum(x), axis=1), reusing the totals computed above).
stacked_data = stacked_data.mul(100).div(total_data, axis=0)

list_of_colors = ['blue', 'red', 'yellow']
ax = stacked_data.plot(
    kind='bar',
    stacked=True,
    width=1,
    figsize=(9, 8),
    edgecolor=None,
    color=list_of_colors,
)
plt.xlabel('Size', fontsize=15)
plt.ylabel('Relative abundance [%]', fontsize=15)
plt.xticks(fontsize=12, rotation=0)  # horizontal tick labels
plt.yticks(fontsize=12)
plt.legend(
    bbox_to_anchor=(1, .965),
    facecolor='white',
    loc=0,
    frameon=True,
    fontsize=12,
    title='Index',
)

# pandas bar plots place category i at x = i (0, 1, 2, ...), so the loop
# index is the bar's x position. Text coordinates default to data units,
# and every stacked bar tops out at 100 %, so y=100 sits on top of the bar.
for x_pos, total in enumerate(total_data):
    ax.text(x_pos, 100, f'n={total}', ha='center', fontsize=18)

plt.show()