I would like to add a colorbar to an animated 3D surface plot.
However, every frame adds another colorbar to the figure, and I don't know how to remove the old ones.
Here is a reproducible example:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib import cm
# One orthographic 3-D axes occupying the whole 7x7-inch figure.
fig = plt.figure(figsize=(7, 7))
ax = fig.add_subplot(1, 1, 1, projection='3d', proj_type='ortho')
# Sample surface z = sin(sqrt(x**2 + y**2)) sampled on a 40x40 grid
# covering [-5, 5) in both directions with step 0.25.
X = np.arange(-5, 5, 0.25)
Y = np.arange(-5, 5, 0.25)
X, Y = np.meshgrid(X, Y)
R = np.sqrt(X**2 + Y**2)  # radial distance from the origin
Z = np.sin(R)
# Kept as a list (not a tuple) so the animation callback can replace X in place.
data = [X, Y, Z]
def myPlot(ax, data):
    """Draw ``data`` as a coolwarm surface on ``ax`` and attach a colorbar.

    ``data`` is unpacked positionally into ``ax.plot_surface`` (here
    ``[X, Y, Z]``). The colorbar is added through the module-level ``fig``
    without a ``cax``, so every call creates another colorbar. This is the
    duplication the question describes.
    """
    surface = ax.plot_surface(
        *data,
        cmap=cm.coolwarm,
        linewidth=0,
        antialiased=False,
    )
    fig.colorbar(surface, shrink=0.5, aspect=5)
def anime(i, ax, data):
    """FuncAnimation callback: shift the surface along x and redraw it.

    ``i`` is the frame value passed by FuncAnimation (unused here).
    ``ax`` and ``data`` arrive through ``fargs``.
    """
    ax.cla()
    # Shift the X grid by 0.1 per frame. Plain assignment of the scalar 0.1
    # would replace the 2-D X grid and make plot_surface fail.
    data[0] += 0.1
    myPlot(ax, data)
# One frame per row of the meshgrid; fargs supplies the extra arguments to anime.
n_frames = len(X)
animation = FuncAnimation(fig, func=anime, frames=range(n_frames), fargs=(ax, data))
How can I keep only one colorbar in the animation?
Kind regards
Answer:
You need to create a dedicated axes for the colorbar up front:
# Two side-by-side cells: a wide one for the 3-D plot and a narrow one
# reserved for the colorbar.
gs = GridSpec(nrows=1, ncols=2, width_ratios=[0.9, 0.05])
fig = plt.figure(figsize=(7, 7))
ax = fig.add_subplot(gs[0, 0], projection='3d', proj_type='ortho')
cbar_ax = fig.add_subplot(gs[0, 1])
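and pass it as the `cax` argument when creating the colorbar. Every frame then draws its colorbar into that same axes instead of stealing space from the figure for a new one: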
# The colorbar fills cbar_ax; shrink/aspect would be ignored once cax is given.
fig.colorbar(surf, cax=cbar_ax)
Complete code:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib import cm
from matplotlib.gridspec import GridSpec
# Layout: a wide cell for the 3-D surface and a narrow cell that
# permanently hosts the colorbar, so it is never re-allocated per frame.
gs = GridSpec(nrows=1, ncols=2, width_ratios=[0.9, 0.05])
fig = plt.figure(figsize=(7, 7))
ax = fig.add_subplot(gs[0, 0], projection='3d', proj_type='ortho')
cbar_ax = fig.add_subplot(gs[0, 1])
# Sample surface z = sin(sqrt(x**2 + y**2)) sampled on a 40x40 grid
# covering [-5, 5) in both directions with step 0.25.
X = np.arange(-5, 5, 0.25)
Y = np.arange(-5, 5, 0.25)
X, Y = np.meshgrid(X, Y)
R = np.sqrt(X**2 + Y**2)  # radial distance from the origin
Z = np.sin(R)
# Kept as a list (not a tuple) so the animation callback can replace X in place.
data = [X, Y, Z]
def myPlot(ax, data):
    """Draw one frame: a coolwarm surface on ``ax`` plus its colorbar.

    ``data`` is unpacked positionally into ``ax.plot_surface`` (here
    ``[X, Y, Z]``). The colorbar is drawn into the module-level ``cbar_ax``
    via the module-level ``fig``, so only one colorbar is ever visible.
    """
    surf = ax.plot_surface(*data, cmap=cm.coolwarm,
                           linewidth=0, antialiased=False)
    # Clear the previous frame's colorbar first; otherwise every call draws
    # a new Colorbar into the same axes on top of the old one.
    cbar_ax.cla()
    # shrink/aspect are omitted: they are ignored when cax is given. The
    # colorbar's size comes from cbar_ax (the GridSpec width ratio).
    fig.colorbar(surf, cax=cbar_ax)
def anime(i, ax, data):
    """FuncAnimation callback: shift the surface along x and redraw it.

    ``i`` is the frame value passed by FuncAnimation (unused here).
    ``ax`` and ``data`` arrive through ``fargs``.
    """
    ax.cla()
    # Shift the X grid by 0.1 per frame. Plain assignment of the scalar 0.1
    # would replace the 2-D X grid and make plot_surface fail.
    data[0] += 0.1
    myPlot(ax, data)
# One frame per row of the meshgrid; fargs supplies the extra arguments to anime.
n_frames = len(X)
animation = FuncAnimation(fig, func=anime, frames=range(n_frames), fargs=(ax, data))
plt.show()