Barplot for each ax in matplotlib

Time:08-20

I have the following dataset of star ratings for two fictitious places:

import pandas as pd
import matplotlib.pyplot as plt
df = pd.DataFrame({'id':['A','A','A','A','A','A','A','B','B','B','B','B','B'], 
                   'rating':[1,2,4,5,5,5,3,1,3,3,3,5,2]})

Since the rating is categorical (not continuous data), I convert it to a category:

df['rating_cat'] = pd.Categorical(df['rating'])
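One practical effect of the categorical dtype: the categories are inferred from all rows, and value_counts on a categorical column also reports categories with zero occurrences. That matters here because place B never gets a 4-star rating. A minimal sketch using only standard pandas calls (the name counts_b is just for illustration):

```python
import pandas as pd

df = pd.DataFrame({'id': ['A','A','A','A','A','A','A','B','B','B','B','B','B'],
                   'rating': [1,2,4,5,5,5,3,1,3,3,3,5,2]})
df['rating_cat'] = pd.Categorical(df['rating'])

# Categories come from the whole column, so both places share 1..5
print(df['rating_cat'].cat.categories.tolist())  # [1, 2, 3, 4, 5]

# On a categorical column, value_counts includes unused categories with a
# count of 0; sort=False keeps category order instead of sorting by frequency
counts_b = df.loc[df['id'] == 'B', 'rating_cat'].value_counts(sort=False)
print(counts_b)
```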

What I want is to create one bar plot for each fictitious place ('A' or 'B') showing the count of each rating. This is the intended plot:

[image: the intended plot]

I guess a for loop over each value in id could work, but I'm having trouble computing the size (count) of each bar:

fig, ax = plt.subplots(1,2,figsize=(6,6))
axs = ax.flatten()
cats = df['rating_cat'].cat.categories.tolist()
ids_uniques = df.id.unique()
for i in range(len(ids_uniques)):
    ax[i].bar(df[df['id']==ids_uniques[i]], df['rating'].size())

But it raises an error: TypeError: 'int' object is not callable

Perhaps what I'm doing is overcomplicated. Could you please guide me with this code?
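For reference, the TypeError comes from df['rating'].size(): Series.size is an integer attribute (the number of elements), not a method, so calling it fails. Separately, ax.bar expects x positions and bar heights rather than a filtered DataFrame. Below is a minimal corrected sketch of the same loop, assuming the goal is one subplot per id with the count of each rating; sharey=True and the titles/labels are my own choices, not part of the original code.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; remove to display a window
import matplotlib.pyplot as plt
import pandas as pd

df = pd.DataFrame({'id': ['A','A','A','A','A','A','A','B','B','B','B','B','B'],
                   'rating': [1,2,4,5,5,5,3,1,3,3,3,5,2]})
df['rating_cat'] = pd.Categorical(df['rating'])

cats = df['rating_cat'].cat.categories.tolist()  # all ratings seen: 1..5
ids_uniques = df['id'].unique()

# squeeze=False always returns a 2-D array, so ravel() works for any number of ids
fig, axs = plt.subplots(1, len(ids_uniques), figsize=(8, 4),
                        sharey=True, squeeze=False)
axs = axs.ravel()

for ax, place in zip(axs, ids_uniques):
    # Count each rating for this place; reindex fills missing ratings with 0
    counts = (df.loc[df['id'] == place, 'rating']
                .value_counts()
                .reindex(cats, fill_value=0))
    # x positions are the ratings, heights are the counts
    ax.bar(cats, counts.to_numpy())
    ax.set_xticks(cats)
    ax.set_title(f"Place {place}")
    ax.set_xlabel("Rating (stars)")

axs[0].set_ylabel("Count")
plt.show()
```

Using reindex with fill_value=0 keeps both subplots on the same five x positions even though place B has no 4-star rating.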

Answer:

A matplotlib-only way (no seaborn; the bars are drawn with pandas' matplotlib-based .plot):

import matplotlib.pyplot as plt
from math import ceil

# Prepare the data for plotting: number of rows per (id, rating) pair
df_plot = df.groupby(["id", "rating"]).size()
unique_ids = df_plot.index.get_level_values("id").unique()

# Calculate the grid spec. This will be an nrows x 2 grid
# with one chart per id
ncols = 2
nrows = ceil(len(unique_ids) / ncols)
fig = plt.figure(figsize=(6, 6))

for i, id_ in enumerate(unique_ids):
    # In a figure grid spanning nrows x ncols, plot into the
    # axes at position i + 1 (subplot positions are 1-based)
    ax = fig.add_subplot(nrows, ncols, i + 1)
    # pandas' plotting keyword for the target axes is ax, not axes
    df_plot.xs(id_).plot(ax=ax, kind="bar")

plt.show()
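Because place B has no 4-star rating, xs(id_) yields only four bars for B, so the two x axes do not line up. One way around that, sketched here with standard pandas calls rather than taken from the answer above, is to build a complete count table with pd.crosstab (combinations that never occur become 0) and let DataFrame.plot(subplots=True) draw one subplot per column; the layout, figsize and title text are arbitrary choices.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; remove to display a window
import matplotlib.pyplot as plt
import pandas as pd

df = pd.DataFrame({'id': ['A','A','A','A','A','A','A','B','B','B','B','B','B'],
                   'rating': [1,2,4,5,5,5,3,1,3,3,3,5,2]})

# Rows = rating, columns = id; (id, rating) pairs that never occur count as 0
counts = pd.crosstab(df['rating'], df['id'])

# One bar subplot per column (i.e. per place) on a 1 x 2 grid
axes = counts.plot(kind='bar', subplots=True, layout=(1, 2), figsize=(8, 4),
                   sharey=True, legend=False, rot=0)

# Label each subplot with its place
for ax, place in zip(axes.ravel(), counts.columns):
    ax.set_title(f"Place {place}")
plt.show()
```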

You can simplify things a lot with Seaborn:

import seaborn as sns

sns.catplot(data=df, x="rating", col="id", col_wrap=2, kind="count")

Answer:

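If you're OK with installing another library, seaborn has a very helpful countplot. Seaborn uses matplotlib under the hood and makes certain plots easier. Note that hue='id' draws both places side by side on a single axes rather than one axes per place: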

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

df = pd.DataFrame({'id':['A','A','A','A','A','A','A','B','B','B','B','B','B'], 
                   'rating':[1,2,4,5,5,5,3,1,3,3,3,5,2]})

sns.countplot(
    data = df,
    x = 'rating',
    hue = 'id',
)
plt.show()
plt.close()

[image: the resulting count plot]
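For comparison, here is a rough pure-matplotlib sketch of the grouped layout that sns.countplot(..., hue='id') produces: one slot per rating, with each place's bars shifted by a fixed width inside the slot. This is my own approximation, not seaborn's implementation; the 0.8 total group width is an arbitrary choice.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; remove to display a window
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

df = pd.DataFrame({'id': ['A','A','A','A','A','A','A','B','B','B','B','B','B'],
                   'rating': [1,2,4,5,5,5,3,1,3,3,3,5,2]})

counts = pd.crosstab(df['rating'], df['id'])  # rows: rating, columns: id
x = np.arange(len(counts.index))              # one slot per rating
width = 0.8 / len(counts.columns)             # split each slot among the places

fig, ax = plt.subplots(figsize=(6, 4))
for i, place in enumerate(counts.columns):
    # Shift each place's bars so the group stays centered on its slot
    offset = (i - (len(counts.columns) - 1) / 2) * width
    ax.bar(x + offset, counts[place].to_numpy(), width, label=place)

ax.set_xticks(x)
ax.set_xticklabels(counts.index)
ax.set_xlabel("rating")
ax.set_ylabel("count")
ax.legend(title="id")
plt.show()
```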
