I have an array `X_trj` of shape (18, 101) that I want to plot in 3D (it holds the trajectories of three different vehicles). I tried animating the plot as follows:
#animate the plot:
import matplotlib.animation as animation
# First, create a function that updates the scatter plot for each frame
def update_plot(n,X_trj,scatters):
    """Per-frame FuncAnimation callback meant to reveal the first n samples.

    Vehicle 0 is read from rows 0-2, vehicle 1 from rows 6-8 and vehicle 2
    from rows 12-14 of X_trj; columns are time steps. Returns the list of
    modified scatter artists.

    NOTE(review): this version does not animate. `set_offsets` sets the
    2-D offsets of a collection, but scatters created on a 3-D axes are
    drawn from their 3-D data, so these updates most likely have no visible
    effect. Assigning `scatter._offsets3d = X_trj[rows, :n]` (a private
    matplotlib attribute) is the usual workaround. Also, because the frames
    run n = 0 .. X_trj.shape[1] - 1, the slice `:n` never includes the
    final sample.
    """
    # Set the data for each scatter plot
    scatters[0].set_offsets(np.stack((X_trj[0, :n], X_trj[1, :n], X_trj[2, :n]), axis=1))
    scatters[1].set_offsets(np.stack((X_trj[6, :n], X_trj[7, :n], X_trj[8, :n]), axis=1))
    scatters[2].set_offsets(np.stack((X_trj[12,:n], X_trj[13, :n], X_trj[14,:n]), axis=1))
    return scatters
# Create the figure and axis
fig = plt.figure()
ax = plt.axes(projection='3d')
# Create the scatter plots
# NOTE(review): each scatter is created with the *complete* trajectory
# (all columns of X_trj), so the whole path is already drawn before the
# first frame. This is the main reason the saved video looks static.
# Starting from empty scatters (and setting the axis limits by hand)
# avoids this.
scatters = []
scatters.append(ax.scatter(X_trj[0,:], X_trj[1,:], X_trj[2,:]))
scatters.append(ax.scatter(X_trj[6,:], X_trj[7,:], X_trj[8,:]))
scatters.append(ax.scatter(X_trj[12,:], X_trj[13,:], X_trj[14,:]))
# Set the title
ax.set_title('Trajectory from one-shot optimization (human drones)')
# One frame per column (time step) of X_trj; fargs forwards X_trj and
# scatters to update_plot after the frame value.
ani = animation.FuncAnimation(fig, update_plot, frames=range(X_trj.shape[1]), fargs=(X_trj, scatters))
plt.show()
# NOTE(review): plt.show() usually blocks until the window is closed, so
# the save only happens afterwards. Calling ani.save(...) before
# plt.show() is the safer order.
ani.save('animation.mp4')
I get the following plot after running the code (the image was not preserved). However, when I open the mp4 file, the animation does not move: it is exactly the same static plot. Any help is greatly appreciated!
Answer:
It is unclear where your starting code comes from. Most examples use `ax.plot` instead of `ax.scatter`, and old code can become obsolete with newer matplotlib versions.
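The main problem is that you already draw the full, final trajectory at initialization, so every frame looks the same. Instead, create empty scatter plots and update their 3D data in each frame. Because empty plots give autoscaling no data to work with, you also have to set the x, y and z limits manually.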
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
# first, generate random-walk test data: 15 rows by n time steps, built
# from the running sum of Gaussian increments along each row
n = 2000
X_trj = np.cumsum(np.random.randn(15, n), axis=1)
# second, create a function that updates the scatter plot for each frame
def update_plot(k, X_trj, scatters):
    """Show samples 0..k (inclusive) of each vehicle's trajectory.

    Called by FuncAnimation once per frame.

    Parameters
    ----------
    k : int
        Frame index. Frame k reveals the first ``k + 1`` samples, so the
        last frame (``k = X_trj.shape[1] - 1``) shows the full trajectory.
    X_trj : numpy.ndarray
        Trajectory array with time along the columns. Vehicle 0 uses rows
        0-2, vehicle 1 rows 6-8 and vehicle 2 rows 12-14 as x, y, z.
    scatters : list
        Scatter artists created on a 3D axes, one per vehicle, in that order.

    Returns
    -------
    list
        ``scatters``, the artists that were modified.
    """
    # Slice up to k + 1 rather than k: with frames=n, k stops at n - 1,
    # so ``:k`` would never draw the final sample.
    for first_row, scatter in zip((0, 6, 12), scatters):
        # _offsets3d is a private matplotlib attribute holding the 3D
        # positions (x, y, z rows); it may change between versions.
        scatter._offsets3d = X_trj[first_row:first_row + 3, :k + 1]
    return scatters
# Build a 3D figure with one initially empty scatter per vehicle.
fig = plt.figure()
ax = plt.axes(projection='3d')
scatters = [ax.scatter([], [], []) for _ in range(3)]

# With empty scatters there is no data to autoscale from, so fix each
# axis to the full extent of that coordinate over all three vehicles.
limit_specs = (
    (ax.set_xlim3d, [0, 6, 12]),
    (ax.set_ylim3d, [1, 7, 13]),
    (ax.set_zlim3d, [2, 8, 14]),
)
for set_limits, rows in limit_specs:
    coords = X_trj[rows, :]
    set_limits(coords.min(), coords.max())

ax.set_title('Trajectory from one-shot optimization (human drones)')

# One frame per time step; render the file before showing the window.
ani = animation.FuncAnimation(fig, update_plot, frames=n, fargs=(X_trj, scatters))
ani.save('animation.mp4')
plt.show()