I am trying to plot a grouped bar plot, but the bars are shifted. In the example below, I want the x ticks to be centered between the green and orange bars, so that the whole stack is centered.
x = [56.35035714285714, 49.573170731707314, 48.100462962962965, 47.098684210526315, 44.452991452991455, 46.82525252525252, 46.99393939393939, 47.09538461538461, 46.35581395348837, 46.05, 46.38235294117647, 47.60588235294117, 42.7875, 44.25, 43.4]
y = [56.316428571428574, 51.560278745644595, 49.66481481481482, 48.367105263157896, 46.80427350427351, 49.59494949494949, 48.94242424242424, 48.395384615384614, 47.97674418604651, 45.86, 48.88823529411765, 48.48823529411764, 44.1875, 49.150000000000006, 44.76666666666667]
z = [59.60678571428572, 52.26794425087108, 51.4324074074074, 49.89934210526316, 47.705128205128204, 49.269696969696966, 49.35454545454546, 48.956923076923076, 48.36046511627907, 48.07, 50.33529411764705, 49.92941176470588, 45.65, 49.25, 46.43333333333334]
r = [59.59571428571428, 53.78989547038328, 52.50092592592593, 49.732236842105266, 48.32905982905983, 48.462626262626266, 49.306060606060605, 49.792307692307695, 48.37674418604651, 48.263333333333335, 49.95882352941176, 48.599999999999994, 45.7875, 50.2, 50.93333333333334]
import numpy as np
import matplotlib.pyplot as plt
x_labels = ["0-5","5-10","10-15","15-20","20-25","25-30","30-35","35-40","40-45","45-50","50-55","55-60","60-65","65-70","70-75"]
w=0.2
bar1 = np.arange(len(x_labels))
bar2 = [i + w for i in bar1]
bar3 = [i + w for i in bar2]
bar4 = [i + w for i in bar3]
plt.figure(figsize=(20,5))
plt.bar(bar1,x,w,label="A")
plt.bar(bar2,y,w,label="B")
plt.bar(bar3,z,w,label="C")
plt.bar(bar4,r,w,label="D")
plt.xlabel("x-axis")
plt.ylabel("y-axis")
plt.title("bar chart")
plt.xticks(bar1,x_labels)
plt.legend()
plt.show()
CodePudding user response:
Since your current bar locations are defined relative to bar1, the quickest fix would be to shift the ticks by 1.5*w (where 1.5 comes from number_of_groups/2 - 0.5):
plt.xticks(bar1 + 1.5*w, x_labels)
But this will put your ticks at non-standard locations [0.3, 1.3, 2.3, ...] (not visible to the end user, but possibly more confusing to maintain).
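To see where the 1.5 comes from, here is a minimal sketch that recomputes the shift from the number of bars per group. The names n_series, shift, and tick_positions are only illustrative and do not appear in the original code.

```python
w = 0.2        # bar width, as in the question
n_series = 4   # bars per group (A, B, C, D)

# Within a group the bars sit at i, i + w, ..., i + (n_series - 1) * w,
# so the centre of the group is half of that span away from the first bar.
shift = (n_series - 1) / 2 * w   # equivalent to (n_series / 2 - 0.5) * w

# Tick locations for the first three groups
tick_positions = [i + shift for i in range(3)]
print(shift, tick_positions)
```

With a different number of series, only n_series would need to change.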
The more standard approach is to use fixed pivot points and shift the bar/tick locations relative to the pivot points:
pivot = np.arange(len(x_labels)) # pivot points for bars/ticks
bar1 = [p - 1.5*w for p in pivot] # shift bars relative to pivot points
bar2 = [p - 0.5*w for p in pivot]
bar3 = [p + 0.5*w for p in pivot]
bar4 = [p + 1.5*w for p in pivot]
# ...
plt.xticks(pivot, x_labels) # place ticks at pivot points, not bar1
Or as a list comprehension:
pivot = np.arange(len(x_labels))
bar1, bar2, bar3, bar4 = [[p + k*w for p in pivot] for k in (-1.5, -0.5, 0.5, 1.5)]
# ...
plt.xticks(pivot, x_labels)
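If the number of series might change, the multipliers (-1.5, -0.5, 0.5, 1.5) do not have to be hard-coded. The following is a rough sketch, assuming NumPy broadcasting is acceptable here; n_series, offsets, and positions are illustrative names, not part of the code above.

```python
import numpy as np

w = 0.2
n_series = 4
pivot = np.arange(15)  # one pivot per x label, as in the question

# Offsets symmetric around zero: (-1.5, -0.5, 0.5, 1.5) * w for four series
offsets = (np.arange(n_series) - (n_series - 1) / 2) * w

# positions[k] holds the x locations of series k (shape: n_series x len(pivot))
positions = pivot + offsets[:, None]

# The mean bar position in each group coincides with its pivot
print(np.allclose(positions.mean(axis=0), pivot))
```

Each row of positions could then be passed to plt.bar in place of bar1 through bar4.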
Full code:
x = [56.35035714285714, 49.573170731707314, 48.100462962962965, 47.098684210526315, 44.452991452991455, 46.82525252525252, 46.99393939393939, 47.09538461538461, 46.35581395348837, 46.05, 46.38235294117647, 47.60588235294117, 42.7875, 44.25, 43.4]
y = [56.316428571428574, 51.560278745644595, 49.66481481481482, 48.367105263157896, 46.80427350427351, 49.59494949494949, 48.94242424242424, 48.395384615384614, 47.97674418604651, 45.86, 48.88823529411765, 48.48823529411764, 44.1875, 49.150000000000006, 44.76666666666667]
z = [59.60678571428572, 52.26794425087108, 51.4324074074074, 49.89934210526316, 47.705128205128204, 49.269696969696966, 49.35454545454546, 48.956923076923076, 48.36046511627907, 48.07, 50.33529411764705, 49.92941176470588, 45.65, 49.25, 46.43333333333334]
r = [59.59571428571428, 53.78989547038328, 52.50092592592593, 49.732236842105266, 48.32905982905983, 48.462626262626266, 49.306060606060605, 49.792307692307695, 48.37674418604651, 48.263333333333335, 49.95882352941176, 48.599999999999994, 45.7875, 50.2, 50.93333333333334]
import numpy as np
import matplotlib.pyplot as plt
x_labels = ['0-5', '5-10', '10-15', '15-20', '20-25', '25-30', '30-35', '35-40', '40-45', '45-50', '50-55', '55-60', '60-65', '65-70', '70-75']
w = 0.2
pivot = np.arange(len(x_labels)) # pivot points for bars/ticks
bar1, bar2, bar3, bar4 = [[p + k*w for p in pivot] for k in (-1.5, -0.5, 0.5, 1.5)]
plt.figure(figsize=(20, 5))
plt.bar(bar1, x, w, label='A')
plt.bar(bar2, y, w, label='B')
plt.bar(bar3, z, w, label='C')
plt.bar(bar4, r, w, label='D')
plt.xlabel('x-axis')
plt.ylabel('y-axis')
plt.title('bar chart')
plt.xticks(pivot, x_labels) # place ticks at pivot points, not bar1
plt.legend()
plt.show()
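As a possible variation, the same pivot idea can be wrapped in a loop so that it works for any number of series. This is only a sketch: grouped_bars is a hypothetical helper, not a matplotlib function, and the small demo data is made up for illustration. The Agg backend is selected only so the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; remove for on-screen display
import matplotlib.pyplot as plt
import numpy as np


def grouped_bars(series, labels, w=0.2):
    """Hypothetical helper: draw one bar per series around each pivot point."""
    pivot = np.arange(len(labels))  # pivot points for bars/ticks
    n = len(series)
    fig, ax = plt.subplots(figsize=(20, 5))
    for k, (name, values) in enumerate(series.items()):
        # shift series k so that the whole group is centered on the pivot
        ax.bar(pivot + (k - (n - 1) / 2) * w, values, w, label=name)
    ax.set_xticks(pivot)            # ticks at pivot points, not at the first bar
    ax.set_xticklabels(labels)
    ax.legend()
    return fig, ax


# Made-up demo data with three series and three groups
demo = {"A": [1, 2, 3], "B": [2, 3, 4], "C": [3, 4, 5]}
fig, ax = grouped_bars(demo, ["0-5", "5-10", "10-15"])
```

Calling it with the x, y, z, r lists from the question and x_labels should give the same layout as the full code above, followed by plt.show() when using an interactive backend.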