(Adjusted to suggestions) I already have a function that draws a plot:
```python
import matplotlib.pyplot as plt


def plot_i(Y, ax=None):
    if ax is None:
        ax = plt.gca()
    fig = plt.figure()
    ax.plot(Y)
    plt.close(fig)
    return fig
```
I want to use it to plot n arrays in a grid. For simplicity, assume the grid is (n // 2, 2) and that n is even. So far I have come up with this:
```python
def multi_plot(Y_arr, function):
    n = len(Y_arr)
    fig, ax = plt.subplots(n // 2, 2)
    for i in range(n):
        # assign to one axis a call of the function = plot_i that draws a plot
        pass  # placeholder: the per-axis call goes here
    plt.close(fig)
    return fig
```
If I do something like this inside the loop:

```python
plot_i(Y[:, i], ax=ax[k, j])
```

the result is correct, but I have to close the figures each time at the end; otherwise figures keep accumulating in pyplot. Is there any way to avoid calling plt.close(fig) every time?
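To see the accumulation described in the question, here is a minimal sketch. It assumes the non-interactive Agg backend so nothing is displayed, and plot_i_leaky is a hypothetical copy of the question's plot_i with the plt.close(fig) call removed. Every plt.figure() call registers a new figure with pyplot, and plt.get_fignums() lists the figures pyplot is still tracking.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, nothing is shown on screen
import matplotlib.pyplot as plt
import numpy as np


def plot_i_leaky(Y, ax=None):
    # hypothetical copy of the question's plot_i without plt.close(fig)
    if ax is None:
        ax = plt.gca()
    fig = plt.figure()  # registers a brand-new, empty figure with pyplot
    ax.plot(Y)
    return fig


plt.close("all")
fig, axes = plt.subplots(2, 2)
Y = np.random.rand(10, 4)
for i in range(4):
    plot_i_leaky(Y[:, i], ax=axes[i // 2, i % 2])

print(len(plt.get_fignums()))  # 5: the grid figure plus one empty figure per call
```

The lines are drawn on the grid's axes as intended; the extra figures come only from the plt.figure() call inside the helper.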
Answer:
If I understand correctly, you are looking for something like this:
```python
import numpy as np
import matplotlib.pyplot as plt


def plot_i(Y, ax=None):
    # Draw on the given axes; fall back to the current axes only if none was passed
    if ax is None:
        ax = plt.gca()
    ax.plot(Y)
    return ax


def multi_plot(Y_arr, function, n_cols=2):
    n = Y_arr.shape[1]  # one subplot per column of Y_arr
    # ceil(n / n_cols) rows; squeeze=False keeps ax 2-D even with a single row
    n_rows = n // n_cols + (1 if n % n_cols else 0)
    fig, ax = plt.subplots(n_rows, n_cols, squeeze=False)
    for i in range(n):
        # assign to one axis a call of the function = plot_i that draws a plot
        function(Y_arr[:, i], ax=ax[i // n_cols, i % n_cols])
    return fig


if __name__ == '__main__':
    x = np.linspace(0, 12.6, 100)
    # let's create some fake data: 14 damped sine waves, one per column
    data = np.exp(-np.linspace(0, .5, 14)[np.newaxis, :] * x[:, np.newaxis]) * np.sin(x[:, np.newaxis])
    fig = multi_plot(data, plot_i, 3)
    plt.show()
```
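A sketch to check that this approach no longer accumulates figures, again assuming the Agg backend: it counts pyplot's open figures after one grid is built. It also shows a variant, multi_plot_flat (a hypothetical name), that walks the axes with ax.flat and switches off the cells left empty when n is not a multiple of n_cols, such as the 15th cell when 14 columns are laid out 3 per row.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for the check
import matplotlib.pyplot as plt
import numpy as np


def plot_i(Y, ax=None):
    # draws on the axes it is given; never creates a figure when ax is passed
    if ax is None:
        ax = plt.gca()
    ax.plot(Y)
    return ax


def multi_plot_flat(Y_arr, function, n_cols=2):
    # hypothetical variant of multi_plot: visit the axes in row-major order
    n = Y_arr.shape[1]
    n_rows = -(-n // n_cols)  # ceiling division, same as n // n_cols + (1 if n % n_cols else 0)
    fig, ax = plt.subplots(n_rows, n_cols, squeeze=False)
    for k, a in enumerate(ax.flat):
        if k < n:
            function(Y_arr[:, k], ax=a)
        else:
            a.set_axis_off()  # hide grid cells that received no data
    return fig


plt.close("all")
data = np.random.rand(100, 14)
fig = multi_plot_flat(data, plot_i, 3)

print(len(plt.get_fignums()))  # 1: only the grid figure is open
print(len(fig.axes))           # 15: a 5 x 3 grid, the last cell hidden
```

Because the helper only draws on the Axes it receives, there is exactly one figure to manage, and the caller decides when to show or close it.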
Be careful when using plt.gca(): if no figure is active, it creates a new figure (with an Axes) instead of failing.
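A short sketch of that pitfall, assuming the Agg backend: with no figure open, plt.gca() silently creates one, so a helper called without an explicit ax leaves behind a figure the caller may never close. Closing the figure that owns the returned Axes releases it again.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt

plt.close("all")
print(plt.get_fignums())  # []: no figure is open

ax = plt.gca()  # no active figure, so pyplot creates one (with an Axes)
print(plt.get_fignums())  # [1]

# Closing the figure that owns the Axes removes it from pyplot's registry
plt.close(ax.figure)
print(plt.get_fignums())  # []
```

Passing an Axes explicitly, as multi_plot does, keeps the gca() branch from ever running.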