I have a plotting function from a library that takes an array and generates a heatmap from it (I'll use imshow here for the sake of the MWE). The function does not return anything; it just calls fig.show():
```python
import matplotlib.pyplot as plt
import numpy as np

# Complicated function from a library which I technically could but should not modify
# simplified for MWE
def heatmap(arr):
    fig, ax = plt.subplots()
    _ = ax.imshow(arr)
    fig.show()
```
If I call this function in a loop, I get one figure per call:
```python
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
```
I want to collect these figures and animate them at the end, like:
```python
plots = []
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
    plots.append(plt.gca())  # what should this actually look like?

# wish this existed
plt.animate(plots)  # ???
```
I do have access to the code for heatmap, so I could technically change it to return the figure and axes, but I would like to find a simple solution that would work even if I had no access to the plotting code.
Is this possible with matplotlib? All the examples I see in the docs update a single figure rather than collecting many different ones.
Answer
Based on the comments, I found a working solution that collects the plots generated in a loop, without modifying the plotting function, and saves them to an animation.
The original loop I was using was the following:
```python
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
```
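Each call leaves its figure open, and pyplot keeps track of it, so the figures can be reached afterwards without touching heatmap. A minimal sketch of that, assuming the Agg backend and a shorter loop (5 iterations) so it runs headless:

```python
import warnings

import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

# On Agg, fig.show() only emits a "non-interactive" UserWarning; silence it
warnings.filterwarnings("ignore", message=".*non-interactive", category=UserWarning)


def heatmap(arr):
    # stand-in for the library function from the question (left unmodified)
    fig, ax = plt.subplots()
    _ = ax.imshow(arr)
    fig.show()


plt.close("all")
for i in range(5):  # 5 instead of 100 to keep the sketch light
    heatmap(np.random.rand(10, 10))

# pyplot keeps a reference to every figure heatmap() opened, keyed by figure number
figures = [plt.figure(num) for num in plt.get_fignums()]
print(len(figures))  # 5
```

With the full 100 iterations, pyplot also emits a RuntimeWarning once more than 20 figures are open (the threshold is rcParams["figure.max_open_warning"]), which is one more reason to close each figure once its contents have been used.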
I'll first give the solution, and then a step-by-step explanation of the logic.
Final Solution
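```python
from matplotlib import animation  # plt, np and heatmap as in the question

plots = []
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
    if i == 0:
        # the first figure becomes the canvas for the animation
        fig, ax = plt.gcf(), plt.gca()
    else:
        dummy_fig, ax = plt.gcf(), plt.gca()
        ax.set(animated=True)
        ax.remove()           # detach ax from the figure heatmap created
        ax.figure = fig       # re-parent it onto the canvas figure
        fig.add_axes(ax)      # register it with the canvas figure
        plt.close(dummy_fig)  # the now-empty figure is no longer needed
    plots.append([ax])        # one frame = a list holding this iteration's axes

ani = animation.ArtistAnimation(fig, plots, interval=50, repeat_delay=200)
ani.save("video.mp4")
```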
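ArtistAnimation is what makes the "collect first, animate later" approach possible: it takes one figure plus a list of frames, where each frame is a list of artists belonging to that figure, and shows only one frame's artists at a time. That is why every Axes has to be moved onto fig before it can be used. The sketch below shows that contract on its own, without heatmap, using plain imshow images as the per-frame artists. The Agg backend, the frame count and the GIF output via PillowWriter are assumptions chosen so it runs without ffmpeg; ani.save("video.mp4") as used above normally needs ffmpeg installed.

```python
from pathlib import Path

import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation

plt.close("all")
fig, ax = plt.subplots()

# Each entry of `frames` is the list of artists shown in one frame.
frames = []
for i in range(30):
    im = ax.imshow(np.random.rand(10, 10), animated=True)
    frames.append([im])

ani = animation.ArtistAnimation(fig, frames, interval=50, repeat_delay=200)

# Pillow is a required dependency of recent matplotlib, so this writer needs no ffmpeg.
out = Path("frames.gif")
ani.save(str(out), writer=animation.PillowWriter(fps=20))
```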
Step-by-step explanation
To save the plots and animate them later, I made the following modifications:
- get a handle to the figure and axes created by each heatmap call:
```python
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
    fig, ax = plt.gcf(), plt.gca()  # add this
```
- use the very first figure as the drawing canvas for all subsequent axes:
```python
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
    if i == 0:  # fig is the one we'll use as our animation canvas
        fig, ax = plt.gcf(), plt.gca()
    else:
        dummy_fig, ax = plt.gcf(), plt.gca()  # we will ignore dummy_fig
        plt.close(dummy_fig)
```
- before closing the other figures, move their axes to our main canvas:
```python
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
    if i == 0:
        fig, ax = plt.gcf(), plt.gca()
    else:
        dummy_fig, ax = plt.gcf(), plt.gca()
        ax.remove()       # remove ax from dummy_fig
        ax.figure = fig   # now assign it to our canvas fig
        fig.add_axes(ax)  # also register it with fig's list of axes
        plt.close(dummy_fig)
```
- mark the axes as animated (this does not seem to be strictly necessary, though):
```python
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
    if i == 0:
        fig, ax = plt.gcf(), plt.gca()
    else:
        dummy_fig, ax = plt.gcf(), plt.gca()
        ax.set(animated=True)  # from a matplotlib example, but doesn't seem needed
        # we could however add info to each plot here, e.g.
        # ax.set(xlabel=f"image {i}")  # this could be done in the i == 0 branch too
        ax.remove()
        ax.figure = fig
        fig.add_axes(ax)
        plt.close(dummy_fig)
```
- finally, collect all of these axes in a list and pass them to ArtistAnimation:
```python
from matplotlib import animation

plots = []
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
    if i == 0:
        fig, ax = plt.gcf(), plt.gca()
    else:
        dummy_fig, ax = plt.gcf(), plt.gca()
        ax.set(animated=True)
        ax.remove()
        ax.figure = fig
        fig.add_axes(ax)
        plt.close(dummy_fig)
    plots.append([ax])

ani = animation.ArtistAnimation(fig, plots, interval=50, repeat_delay=200)
ani.save("video.mp4")
```
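A different technique, if re-parenting Axes proves fragile (assigning ax.figure by hand is not something matplotlib's public API advertises, and it may behave differently across versions), is to rasterize each figure right after heatmap draws it, close that figure, and replay the bitmaps on a single canvas. The sketch below does that with fig.canvas.draw() and fig.canvas.buffer_rgba(); snapshot_current_figure is a hypothetical helper name, and the Agg backend, frame count and GIF output are assumptions. The trade-off is that the frames become fixed-resolution images, so they can no longer be restyled afterwards.

```python
import warnings
from pathlib import Path

import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation

# On Agg, fig.show() only emits a "non-interactive" UserWarning; silence it
warnings.filterwarnings("ignore", message=".*non-interactive", category=UserWarning)


def heatmap(arr):
    # stand-in for the library function from the question (left unmodified)
    fig, ax = plt.subplots()
    _ = ax.imshow(arr)
    fig.show()


def snapshot_current_figure():
    """Hypothetical helper: render the current figure to an RGBA array, then close it."""
    fig = plt.gcf()
    fig.canvas.draw()  # make sure the Agg buffer holds the finished image
    frame = np.asarray(fig.canvas.buffer_rgba()).copy()  # copy before the figure is closed
    plt.close(fig)
    return frame


plt.close("all")
frames = []
for i in range(20):
    heatmap(np.random.rand(10, 10))
    frames.append(snapshot_current_figure())

# Replay the snapshots on one figure sized to roughly match them
height, width = frames[0].shape[:2]
dpi = 100
fig = plt.figure(figsize=(width / dpi, height / dpi), dpi=dpi)
ax = fig.add_axes([0, 0, 1, 1])  # axes fill the whole figure
ax.set_axis_off()
artists = [[ax.imshow(frame, animated=True)] for frame in frames]

ani = animation.ArtistAnimation(fig, artists, interval=50, repeat_delay=200)
out = Path("snapshots.gif")
ani.save(str(out), writer=animation.PillowWriter(fps=20))
```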