animate multiple figures generated in a for loop

Time:06-20

I have a plotting function from a library that takes an array and generates a heatmap from it (I'll use ax.imshow here for the sake of the MWE). The function does not return anything; it just calls fig.show():

```python
import matplotlib.pyplot as plt
import numpy as np

# Complicated function from a library which I technically could but should not modify
# simplified for MWE
def heatmap(arr):
    fig, ax = plt.subplots()
    _ = ax.imshow(arr)
    fig.show()
```

If I call this function in a loop, I get a separate figure for every call:

```python
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
```
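A quick way to confirm this, assuming a headless run with the Agg backend, is to count the open figures with `plt.get_fignums()` after a few calls. The `heatmap` below is the same stand-in as above, and the loop is shortened to 5 calls for illustration.

```python
import warnings

import matplotlib
matplotlib.use("Agg")  # headless backend; fig.show() only warns here
import matplotlib.pyplot as plt
import numpy as np


def heatmap(arr):
    # Stand-in for the library function from the question.
    fig, ax = plt.subplots()
    _ = ax.imshow(arr)
    fig.show()


plt.close("all")  # start from a clean slate
with warnings.catch_warnings():
    # Agg cannot display windows, so silence the "non-interactive" warning.
    warnings.simplefilter("ignore")
    for i in range(5):
        heatmap(np.random.rand(10, 10))

open_figures = plt.get_fignums()
print(len(open_figures))  # 5: every call left its own figure open
```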

I want to collect these figures and animate them at the end, like:

```python
plots = []
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
    plots.append(plt.gca())  # what should this actually look like?

# wish this existed
plt.animate(plots) # ???
```

I do have access to the code for heatmap, so I could technically change it to return the figure and axes, but I would like a simple solution that would work even if I had no access to the plotting code.

Is this possible with matplotlib? All the examples I see in the docs update a single figure rather than collecting many different ones.
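For comparison, the pattern the docs usually show keeps one figure and one `AxesImage` and swaps the data on every frame with `FuncAnimation`. It requires owning the plotting code, which is exactly what the question wants to avoid. A minimal sketch, assuming the Agg backend and a GIF written with `PillowWriter`; the file name, frame count and frame rate are arbitrary choices:

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # render off-screen
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation

rng = np.random.default_rng(0)
frames_data = [rng.random((10, 10)) for _ in range(20)]

# One figure, one image artist: only its data changes between frames.
fig, ax = plt.subplots()
im = ax.imshow(frames_data[0])


def update(frame_index):
    im.set_data(frames_data[frame_index])
    return [im]  # the artist that changed in this frame


ani = animation.FuncAnimation(fig, update, frames=len(frames_data), interval=50)
out_path = os.path.join(tempfile.mkdtemp(), "update_one_figure.gif")
ani.save(out_path, writer=animation.PillowWriter(fps=20))
```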

Answer:

Based on the comments, I found a working solution that collects the plots generated in a loop, without having to touch the plotting function, and saves them to an animation.

The original loop I was using was the following:

```python
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
```

I'll first give the solution, and then a step-by-step explanation of the logic.

Final Solution

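```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation

# heatmap() is the library function from the question.
plots = []
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
    if i == 0:
        # The first figure becomes the canvas for the animation.
        fig, ax = plt.gcf(), plt.gca()
    else:
        # Move the new axes onto the first figure, then drop its own figure.
        dummy_fig, ax = plt.gcf(), plt.gca()
        ax.set(animated=True)
        ax.remove()
        ax.figure = fig
        fig.add_axes(ax)
        plt.close(dummy_fig)

    plots.append([ax])  # each frame is a list of artists

ani = animation.ArtistAnimation(fig, plots, interval=50, repeat_delay=200)
ani.save("video.mp4")  # the default .mp4 writer needs ffmpeg installed
```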
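`ani.save("video.mp4")` relies on Matplotlib's default movie writer, ffmpeg, which has to be installed separately. A small sketch that falls back to a GIF through `PillowWriter` when ffmpeg is missing. `save_animation` is a hypothetical helper name, and the tiny three-frame animation only exists to have something to save:

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation


def save_animation(ani, stem, fps=20):
    """Hypothetical helper: write an MP4 if ffmpeg is available, else a GIF."""
    if animation.writers.is_available("ffmpeg"):
        path, writer = stem + ".mp4", animation.FFMpegWriter(fps=fps)
    else:
        path, writer = stem + ".gif", animation.PillowWriter(fps=fps)
    ani.save(path, writer=writer)
    return path


# Tiny ArtistAnimation: each frame is a one-element list holding an AxesImage.
fig, ax = plt.subplots()
frames = [[ax.imshow(np.random.rand(10, 10))] for _ in range(3)]
ani = animation.ArtistAnimation(fig, frames, interval=50, repeat_delay=200)

saved = save_animation(ani, os.path.join(tempfile.mkdtemp(), "video"))
print(saved)
```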

Step-by-step explanation

To save the plots and animate them later, I made the following modifications:

  1. Get a handle to the figure and axes created inside heatmap:
```python
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
    fig, ax = plt.gcf(), plt.gca()  # add this
```
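This works because `plt.subplots` makes the figure it creates the current one, so `plt.gcf()` and `plt.gca()` called right after `heatmap` return exactly the objects the library built. A small check, assuming the question's stand-in `heatmap` and the Agg backend:

```python
import warnings

import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np


def heatmap(arr):
    # Stand-in for the library function from the question.
    fig, ax = plt.subplots()
    _ = ax.imshow(arr)
    fig.show()


arr = np.random.rand(10, 10)
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # fig.show() warns under Agg
    heatmap(arr)

fig, ax = plt.gcf(), plt.gca()
image = ax.images[0]  # the AxesImage created by ax.imshow inside heatmap
print(ax.get_figure() is fig, len(ax.images))
```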
  2. Use the very first figure as the drawing canvas for all later axes:
```python
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
    if i == 0:  # fig is the one we'll use for our animation canvas.
        fig, ax = plt.gcf(), plt.gca()
    else:
        dummy_fig, ax = plt.gcf(), plt.gca()  # we will ignore dummy_fig
        plt.close(dummy_fig)
```
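Closing each extra figure right away keeps only the first one open, and once a figure is closed pyplot falls back to a remaining open figure as the current one. A sketch that checks this bookkeeping (without moving any axes yet), assuming the same stand-in `heatmap` and a shortened loop of 10 calls:

```python
import warnings

import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np


def heatmap(arr):
    # Stand-in for the library function from the question.
    fig, ax = plt.subplots()
    _ = ax.imshow(arr)
    fig.show()


plt.close("all")
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # fig.show() warns under Agg
    for i in range(10):
        heatmap(np.random.rand(10, 10))
        if i == 0:
            fig = plt.gcf()  # keep the first figure as the canvas
        else:
            dummy_fig = plt.gcf()
            plt.close(dummy_fig)  # discard every later figure

print(plt.get_fignums(), plt.gcf() is fig)
```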
  3. Before closing the other figures, move their axes to the main canvas:
```python
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
    if i == 0:
        fig, ax = plt.gcf(), plt.gca()
    else:
        dummy_fig, ax = plt.gcf(), plt.gca()
        ax.remove()  # remove ax from dummy_fig
        ax.figure = fig  # now assign it to our canvas fig
        fig.add_axes(ax)  # also patch the fig axes to know about it
        plt.close(dummy_fig)
```
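Reassigning `ax.figure` by hand leans on Matplotlib internals: the library normally refuses to put one artist in more than one figure, so this step may behave differently across Matplotlib versions. If it gives trouble, a more conservative alternative (a different technique, not part of the original answer) is to copy each figure's rendered pixels instead of its axes. The sketch below rasterizes a figure into an RGBA array with the Agg canvas; `figure_to_rgba` is a hypothetical helper name:

```python
import matplotlib
matplotlib.use("Agg")  # Agg canvas provides buffer_rgba()
import matplotlib.pyplot as plt
import numpy as np


def figure_to_rgba(fig):
    """Hypothetical helper: render *fig* and return a copy of its pixels."""
    fig.canvas.draw()  # make sure the Agg buffer is up to date
    return np.asarray(fig.canvas.buffer_rgba()).copy()


fig, ax = plt.subplots(figsize=(4, 3), dpi=100)
ax.imshow(np.random.rand(10, 10))
pixels = figure_to_rgba(fig)
plt.close(fig)  # the figure can be closed once its pixels are copied
print(pixels.shape)  # (300, 400, 4): a 4x3 inch figure at 100 dpi, RGBA
```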
  4. Set the axes to be animated (this does not seem to be strictly necessary, though):
```python
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
    if i == 0:
        fig, ax = plt.gcf(), plt.gca()
    else:
        dummy_fig, ax = plt.gcf(), plt.gca()
        ax.set(animated=True)  # taken from a Matplotlib example, but doesn't seem needed
        # we could however add info to each plot here, e.g.
        # ax.set(xlabel=f"image {i}")  # could also be done in the i == 0 branch
        ax.remove()
        ax.figure = fig
        fig.add_axes(ax)
        plt.close(dummy_fig)
```
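`ax.set(...)` is the generic `Artist.set` batch setter, so `ax.set(animated=True, xlabel=...)` does the same as calling `ax.set_animated(True)` and `ax.set_xlabel(...)`. A small sketch of the per-frame label idea from the code comments; the frame index and label text are just examples:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
ax.imshow(np.random.rand(10, 10))

i = 3  # frame index, as in the loop
ax.set(animated=True, xlabel=f"image {i}")  # same as set_animated + set_xlabel

print(ax.get_animated(), ax.get_xlabel())
```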
  5. Finally, collect all of these axes in a list and animate them:
```python
from matplotlib import animation

plots = []
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
    if i == 0:
        fig, ax = plt.gcf(), plt.gca()
    else:
        dummy_fig, ax = plt.gcf(), plt.gca()
        ax.set(animated=True)
        ax.remove()
        ax.figure = fig
        fig.add_axes(ax)
        plt.close(dummy_fig)

    plots.append([ax])

ani = animation.ArtistAnimation(fig, plots, interval=50, repeat_delay=200)
ani.save("video.mp4")
```
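For completeness, here is a sketch of the pixel-copying alternative end to end: each library-made figure is rendered to an array, closed, and shown as one frame on a single canvas via `ArtistAnimation`. It assumes the question's stand-in `heatmap`, the Agg backend and a GIF written with `PillowWriter`; `figure_to_rgba` is a hypothetical helper. This is not the answer's method, and the frames are pictures of the figures, so they are no longer editable Matplotlib axes.

```python
import os
import tempfile
import warnings

import matplotlib
matplotlib.use("Agg")  # headless backend with buffer_rgba()
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation


def heatmap(arr):
    # Stand-in for the library function from the question.
    fig, ax = plt.subplots()
    _ = ax.imshow(arr)
    fig.show()


def figure_to_rgba(fig):
    """Hypothetical helper: render *fig* and return a copy of its pixels."""
    fig.canvas.draw()
    return np.asarray(fig.canvas.buffer_rgba()).copy()


snapshots = []
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # fig.show() warns under Agg
    for i in range(20):
        heatmap(np.random.rand(10, 10))
        lib_fig = plt.gcf()  # the figure heatmap just created
        snapshots.append(figure_to_rgba(lib_fig))
        plt.close(lib_fig)

# One canvas sized to match the snapshots, with an axes filling it.
height, width = snapshots[0].shape[:2]
dpi = 100
canvas = plt.figure(figsize=(width / dpi, height / dpi), dpi=dpi)
ax = canvas.add_axes([0, 0, 1, 1])
ax.set_axis_off()

# Each frame is a one-element list holding one AxesImage.
frames = [[ax.imshow(img, animated=True)] for img in snapshots]
ani = animation.ArtistAnimation(canvas, frames, interval=50, repeat_delay=200)

out_path = os.path.join(tempfile.mkdtemp(), "heatmaps.gif")
ani.save(out_path, writer=animation.PillowWriter(fps=20))
```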