pyplot: create subplots by looping - axs[i] interpreted as tuple?

Time:01-15

I am trying to create subplots depicting the frequencies of random values from a normal distribution.

import matplotlib.pyplot as plt
import scipy.stats

fig, axs = plt.subplots(3,1, sharex = True)
i = 0

axs = axs.ravel()
for n in [100, 1000, 10000]:
    random_variables = scipy.stats.norm.rvs(loc=10, scale=3, size=n)   
    axs[i] = plt.hist(random_variables)
    axs[i].set_ylabel('number of occurrences of this value')
    i += 1
axs.set_xlabel('value of random sample')

plt.tight_layout()
plt.show()

For axs[i].set_ylabel('number of occurrences of this value') I get AttributeError: 'tuple' object has no attribute 'set_ylabel', and I'm also very confused because (I think) all iterations get plotted into the third histogram. Why is axs[i] interpreted as a tuple? I thought axs.ravel() would make it possible to iterate, but obviously I was wrong. Could you give me a tip, please?

Answer:

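It looks like the issue is that you're reassigning axs[i] to the result of plt.hist(), which returns a tuple (n, bins, patches) rather than an Axes. So when you then call set_ylabel on axs[i], you are calling it on a tuple, which has no such method. This also explains why everything lands in the third subplot: plt.hist() always draws on the current Axes, which after plt.subplots() is the last subplot created.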
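You can confirm both effects by inspecting what plt.hist() returns and where it drew. This is a minimal sketch that assumes the headless Agg backend (so it runs without a display) and uses NumPy's default_rng().normal() in place of scipy.stats.norm.rvs():

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
data = rng.normal(loc=10, scale=3, size=100)

fig, axs = plt.subplots(3, 1, sharex=True)

# plt.hist() returns a 3-tuple (counts, bin_edges, patches), not an Axes
result = plt.hist(data)
print(type(result).__name__, len(result))  # tuple 3
print(hasattr(result, "set_ylabel"))       # False

# plt.hist() drew on the current Axes: the last subplot plt.subplots() created
print(plt.gca() is axs[-1])                # True
print([len(ax.patches) for ax in axs])     # [0, 0, 10]
```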

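Instead, call the hist() method of the subplot itself, axs[i].hist(...), which draws on that specific Axes and leaves axs[i] untouched. (pyplot's plt.hist() has no ax parameter; it always draws on the current Axes.) For example: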

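import matplotlib.pyplot as plt
import scipy.stats

fig, axs = plt.subplots(3, 1, sharex=True)
axs = axs.ravel()
for i, n in enumerate([100, 1000, 10000]):
    random_variables = scipy.stats.norm.rvs(loc=10, scale=3, size=n)
    axs[i].hist(random_variables)  # <- here: draw on this subplot, don't reassign axs[i]
    axs[i].set_ylabel('number of occurrences of this value')
axs[-1].set_xlabel('value of random sample')  # axs is an array, so label one Axes
plt.tight_layout()
plt.show()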

I also noticed that you only share the x axis; you can share the y axis as well (sharey=True):

fig, axs = plt.subplots(3, 1, sharex=True, sharey=True)
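To see what sharey=True changes, here is a minimal sketch (assuming the headless Agg backend, with NumPy's random generator standing in for scipy.stats.norm.rvs()): after plotting, all three subplots report the same y-limits.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
fig, axs = plt.subplots(3, 1, sharex=True, sharey=True)

for ax, n in zip(axs, [100, 1000, 10000]):
    ax.hist(rng.normal(loc=10, scale=3, size=n))

# With sharey=True every subplot has the same y-limits,
# which are driven by the tallest histogram (the n=10000 sample)
ylims = [ax.get_ylim() for ax in axs]
print(ylims[0] == ylims[1] == ylims[2])  # True
```

One consequence worth weighing: because the shared limits follow the n = 10000 histogram, the n = 100 histogram is drawn on the same scale and its bars will look very short.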

Answer:

An idiom I often use is:

fig, axes = plt.subplots(nrows=3, ncols=1, sharex=True)
for ax, n in zip(axes.ravel(), [100, 1000, 10000]):
    ...
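Filled in for this question, the idiom might look like the sketch below. It assumes the headless Agg backend and uses NumPy's default_rng().normal() in place of scipy.stats.norm.rvs(); the per-subplot titles are an illustrative addition:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
fig, axes = plt.subplots(nrows=3, ncols=1, sharex=True)

# zip pairs each Axes with one sample size; no index counter needed
for ax, n in zip(axes.ravel(), [100, 1000, 10000]):
    ax.hist(rng.normal(loc=10, scale=3, size=n))
    ax.set_title(f"n = {n}")
    ax.set_ylabel("count")

# With sharex=True, labelling the bottom subplot's x axis is enough
axes[-1].set_xlabel("value of random sample")
fig.tight_layout()
```

Because each loop iteration only calls methods on ax, nothing is reassigned and each histogram stays in its own subplot.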

Sometimes, when the number of plots doesn't fit a nice nrows * ncols grid, I like to use itertools.zip_longest() and fig.delaxes(ax) to remove the empty plots, as in this answer.
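A minimal sketch of that approach, assuming the headless Agg backend and a hypothetical list of five sample sizes placed on a 2 x 3 grid: zip_longest() pads the shorter iterable with None, so the leftover Axes can be detected and removed.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from itertools import zip_longest

rng = np.random.default_rng(0)
sizes = [10, 100, 1000, 10000, 100000]  # hypothetical: 5 datasets for 6 grid cells

fig, axes = plt.subplots(nrows=2, ncols=3)

for ax, n in zip_longest(axes.ravel(), sizes):
    if n is None:
        # zip_longest ran out of sizes: this Axes would stay empty, so delete it
        fig.delaxes(ax)
        continue
    ax.hist(rng.normal(loc=10, scale=3, size=n))
    ax.set_title(f"n = {n}")

print(len(fig.axes))  # 5
```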
