How do I use a for loop and the zip function to plot subplots in python?

Time:02-18

I'm new to Python and stuck on using the zip function to plot subplots. Any help is highly appreciated!

I would like to plot bar charts for my nutrients in the dataframe, where I can distinguish between organic and non-organic. For each country in my dataframe I want subplots. This is my dataframe:

country  organic  nutrients      values
US       False    carbohydrates  45
US       True     carbohydrates  41
DE       False    calcium        37
DE       True     calcium        31

And this is what my plot should look like: [screenshot of expected plot]

As I have many countries in my dataframe, I would like to use a for loop and the zip function. But I get an empty canvas when I run this code:

fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(30, 15))

for i, ax in zip(range(0, 5), axes.flatten()):

    # Filter data for each country
    country_df = df[df["country"] == i]

    # Populate subplots
    ax.bar(country_df["nutrients"], country_df["values"], hue="organic")

    # Add x/y labels
    ax.set_xlabel("nutrients")
    ax.set_ylabel("average value")

Can somebody help? Thanks!!

CodePudding user response:

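The problem with your code is the country filter: you compare the "country" column against i, an integer from range(0, 5), and that never matches a country name such as "US" or "DE". Every filtered dataframe is therefore empty and nothing gets drawn. Note also that matplotlib's ax.bar has no hue parameter; hue is a seaborn concept. To easily create the grouping you show in the figure, I recommend using seaborn. A possible solution can be: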

import math

import matplotlib.pyplot as plt
import seaborn as sns

# Unique country names, in order of first appearance
countries = df["country"].unique()
# Calculate the number of subplots based on the number of countries
ncols = 3
nrows = math.ceil(len(countries) / ncols)
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(30, 15))

for country, ax in zip(countries, axes.flatten()):
    # Filter data for each country
    country_df = df[df["country"] == country]   # Select by country name

    # Populate subplots; hue splits the bars into organic / non-organic
    sns.barplot(data=country_df, x="nutrients", y="values", hue="organic", ax=ax)

    # Add title and x/y labels
    ax.set_title(country)
    ax.set_xlabel("nutrients")
    ax.set_ylabel("average value")

plt.show()
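
If you would rather not add seaborn, here is a rough sketch of the same loop with only pandas and matplotlib. It is not the asker's exact setup: the dataframe is rebuilt from the four sample rows in the question, the Agg backend is chosen only so the script runs without a display, and names such as flat_axes and wide are made up for this example. It pivots each country's rows so that each organic flag becomes its own bar series, and it hides the grid cells that zip never reaches (zip stops at the shorter of its inputs).

```python
import math

import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so this runs without a display
import matplotlib.pyplot as plt
import pandas as pd

# Sample rows copied from the question
df = pd.DataFrame({
    "country": ["US", "US", "DE", "DE"],
    "organic": [False, True, False, True],
    "nutrients": ["carbohydrates", "carbohydrates", "calcium", "calcium"],
    "values": [45, 41, 37, 31],
})

countries = df["country"].unique()  # unique names, in order of first appearance
ncols = 3
nrows = math.ceil(len(countries) / ncols)

# squeeze=False keeps `axes` two-dimensional even when there is a single row
fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                         figsize=(15, 5 * nrows), squeeze=False)
flat_axes = axes.flatten()

for country, ax in zip(countries, flat_axes):
    country_df = df[df["country"] == country]  # filter by name, not by index
    # One row per nutrient, one column per organic flag -> grouped bars
    wide = country_df.pivot(index="nutrients", columns="organic", values="values")
    wide.plot.bar(ax=ax, rot=0)
    ax.set_title(country)
    ax.set_xlabel("nutrients")
    ax.set_ylabel("average value")

# zip stopped after the last country; hide the grid cells that were not used
for ax in flat_axes[len(countries):]:
    ax.set_visible(False)
```

In an interactive session you would drop the matplotlib.use("Agg") line and finish with plt.show(), or call fig.savefig(...) to write the figure to a file.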