How to create stacked bar chart from a multi-level pivot table

Time:08-19

I've created a pivot table with two levels of columns:

pivotCust = bigData.pivot_table(index=['month'],columns=['year','usertype'],values='start_time',aggfunc = 'count')

This creates the table that I'm interested in:

year        2019            2020            2021
usertype  casual  member  casual  member  casual  member
month
1           4602   98670    7785  136099   18117   78717
2           2638   93548   12870  126715   10131   39491
3          15923  149688   27825  115593   84033  144463
4          47744  217566   23628   61148  136601  200629
5          81624  285834   86909  113365  256916  274717
6         130218  345177  154718  188287  370681  358914
7         175632  381683  269296  282184  442056  380354
8         186889  403295  289661  332700  412671  391681
9         129173  364046  230692  302266  363890  392257
10         71035  300751  145012  243641  257242  373984
11         18729  158447   88099  171617  106929  253049
12         16430  138662   30080  101493   69738  177802

But when I try to turn it into a bar graph (with the code below), it's hard to read, because it creates 72 bars: six per month (casual/member for each of 3 years), for 12 months.

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import io #just used for reading in the example table

pivotCust = pd.read_csv(io.StringIO("""
1   4602    98670   7785    136099  18117   78717
2   2638    93548   12870   126715  10131   39491
3   15923   149688  27825   115593  84033   144463
4   47744   217566  23628   61148   136601  200629
5   81624   285834  86909   113365  256916  274717
6   130218  345177  154718  188287  370681  358914
7   175632  381683  269296  282184  442056  380354
8   186889  403295  289661  332700  412671  391681
9   129173  364046  230692  302266  363890  392257
10  71035   300751  145012  243641  257242  373984
11  18729   158447  88099   171617  106929  253049
12  16430   138662  30080   101493  69738   177802"""
),sep=r'\s+', header=None, index_col=0) #sep=r'\s+' replaces delim_whitespace=True, which newer pandas versions deprecate

pivotCust.index.name = 'month'

pivotCust.columns = pd.MultiIndex.from_product([
    [2019,2020,2021],
    ['casual','member'],
], names=['year','usertype'])


plot_df = pivotCust.reset_index().melt(id_vars='month')
plot_df['casual_member_sum'] = plot_df.groupby(['month','year'])['value'].transform('sum')

fig,ax = plt.subplots()

#Plot the sum of the two categories as background bars
sns.barplot(
    x = 'month',
    y = 'casual_member_sum',
    palette = 'Blues',
    hue = 'Total ' + plot_df['year'].astype(str),
    ax = ax,
    data = plot_df,
)

#Plot just the members as foreground bars
sns.barplot(
    x = 'month',
    y = 'value',
    palette = 'Reds',
    hue = 'Member ' + plot_df['year'].astype(str),
    ax = ax,
    data = plot_df[plot_df['usertype'].eq('member')],
)

plt.show()
plt.close()
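
In case the reshaping step is unclear, here is a minimal sketch of the long-form table that the plots rely on, built from only the first two months of the question's data. It is an illustration rather than the exact code above: the names pivot_small and long_df are made up for this example, and it uses unstack() in place of reset_index().melt(), which should give the same rows with the columns in a different order.

```python
import pandas as pd

# First two months of the question's table (pivot_small is a hypothetical name)
columns = pd.MultiIndex.from_product(
    [[2019, 2020, 2021], ["casual", "member"]],
    names=["year", "usertype"],
)
pivot_small = pd.DataFrame(
    [
        [4602, 98670, 7785, 136099, 18117, 78717],
        [2638, 93548, 12870, 126715, 10131, 39491],
    ],
    index=pd.Index([1, 2], name="month"),
    columns=columns,
)

# unstack() moves the month index under the two column levels, giving one
# value per (year, usertype, month); reset_index() turns those levels into columns
long_df = pivot_small.unstack().rename("value").reset_index()

# Total of casual + member for each (month, year), repeated on both rows of the pair
long_df["casual_member_sum"] = (
    long_df.groupby(["month", "year"])["value"].transform("sum")
)

print(long_df)
```

Each row now holds one count, so seaborn can map month to x, year to hue, and usertype to style or to a filter.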

Here's the lineplot approach with seaborn, using the same plot_df created above. The lineplot is easy to make too:

sns.lineplot(
    x = 'month',
    y = 'value',
    hue = 'year',
    style = 'usertype',
    data = plot_df,
)
plt.show()
plt.close()
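
If literally stacked bars are wanted, one more option is to let pandas stack casual on top of member and give each year its own subplot. This is only a sketch under a few assumptions: the one-panel-per-year layout, the figure size, and the y-axis label are my choices and not something from the question. It rebuilds pivotCust from the same numbers so that it runs on its own.

```python
import io

import matplotlib.pyplot as plt
import pandas as pd

# Same counts as the pivot table in the question
raw = """\
1 4602 98670 7785 136099 18117 78717
2 2638 93548 12870 126715 10131 39491
3 15923 149688 27825 115593 84033 144463
4 47744 217566 23628 61148 136601 200629
5 81624 285834 86909 113365 256916 274717
6 130218 345177 154718 188287 370681 358914
7 175632 381683 269296 282184 442056 380354
8 186889 403295 289661 332700 412671 391681
9 129173 364046 230692 302266 363890 392257
10 71035 300751 145012 243641 257242 373984
11 18729 158447 88099 171617 106929 253049
12 16430 138662 30080 101493 69738 177802
"""
pivotCust = pd.read_csv(io.StringIO(raw), sep=r"\s+", header=None, index_col=0)
pivotCust.index.name = "month"
pivotCust.columns = pd.MultiIndex.from_product(
    [[2019, 2020, 2021], ["casual", "member"]],
    names=["year", "usertype"],
)

years = pivotCust.columns.get_level_values("year").unique()

# One panel per year, sharing the y-axis so heights are comparable across years
fig, axes = plt.subplots(1, len(years), figsize=(15, 4), sharey=True)

for ax, year in zip(axes, years):
    # Selecting a year from the top column level leaves a plain
    # month x usertype frame, which pandas can stack directly
    pivotCust[year].plot(kind="bar", stacked=True, ax=ax, title=str(year), rot=0)
    ax.set_ylabel("count of start_time")

fig.tight_layout()
```

Call plt.show() afterwards to display the figure. Each panel then has 12 bars instead of 24, and the full bar height is the casual plus member total for that month.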