I've created a function that plays an animation using a slider. Each frame consists of a heatmap (with colorbar) above a barplot. The function's arguments consist of a list of text labels, used both for the heatmap axis labels and for the barplot horizontal axis labels, a list of matrices, and a list of lists to be used for the barplot. There is also a time window value, `win_value`, so that frame zero corresponds to time zero, frame one corresponds to `win_value`, frame two to `2*win_value`, and so on.
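The frame/time correspondence described above can be sketched in plain Python. `frame_to_time` and `time_to_frame` are hypothetical helper names used only for illustration. The inverse mapping rounds instead of truncating, because dividing two floats can give a result just below the intended integer.

```python
def frame_to_time(k, win_value):
    """Time (in seconds) of frame k: frame k -> k * win_value."""
    return k * win_value


def time_to_frame(t, win_value):
    """Inverse mapping (hypothetical helper): round to the nearest frame
    instead of truncating, since t / win_value is a floating-point result."""
    return int(round(t / win_value))


win_value = 0.25
times = [frame_to_time(k, win_value) for k in range(4)]
print(times)                           # [0.0, 0.25, 0.5, 0.75]
print(time_to_frame(0.75, win_value))  # 3

# Why rounding matters: 0.3 / 0.1 is 2.9999999999999996 in floating point
print(int(0.3 / 0.1), time_to_frame(0.3, 0.1))  # 2 3
```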
The code for the function is as follows:
```python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.widgets import Slider

def heatmap_barplot_animation(labels, M_list, bar_list, win_value):
    num_times = len(M_list)
    fig, ax = plt.subplots(2)
    plt.subplots_adjust(left=None, bottom=.2, right=None, top=.9, wspace=.2, hspace=.2)
    # Slider axes: [left, bottom, width, height] in figure coordinates
    ax_time = fig.add_axes([0.25, 0.1, 0.65, 0.03])
    s_time = Slider(ax_time, 'Time', 0, num_times, valinit=0, valstep=1)

    def update_graph(val):
        i = int(s_time.val)  # slider value used as the frame index
        # Top: heatmap with a colorbar on its right
        ax[0].cla()
        heatmap = ax[0].imshow(M_list[i - 1], vmin=0, vmax=1, cmap='coolwarm', aspect='auto')
        ax[0].set_xticks(range(len(labels)))
        ax[0].set_xticklabels(labels, fontsize=10)
        ax[0].set_yticks(range(len(labels)))
        ax[0].set_yticklabels(labels, fontsize=10)
        ax0_divider = make_axes_locatable(ax[0])
        cax0 = ax0_divider.append_axes('right', size='7%', pad='2%')
        cb = fig.colorbar(heatmap, cax=cax0, orientation='vertical')
        # Bottom: barplot
        ax[1].cla()
        ax[1].bar(labels, bar_list[i - 1])
        ax[1].set_ylim(0, 1)
        plt.show()

    s_time.on_changed(update_graph)
    s_time.set_val(0)
```
An example with eight labels, 10 frames, and a window value of .25 seconds:
```python
import random

labels = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
Nc = len(labels)  # number of labels (rows/columns of each matrix, bars per frame)
Nt = 10           # number of frames
M_list = [np.random.rand(Nc, Nc) for i in range(Nt)]
bar_list = [[random.uniform(0, 1) for i in range(Nc)] for t in range(Nt)]
win_value = .25
heatmap_barplot_animation(labels, M_list, bar_list, win_value)
```
The third frame of the animation looks like this:
I can't seem to figure out what modifications are needed to do the following:
- Center the slider under the barplot.
- Change the slider so that, instead of showing the index (3 in the frame above), it shows the corresponding time value, .75 seconds in this case.
Answer:
For your first question, one way to center your slider under the subplots is to adjust the position of your subplots with `plt.subplots_adjust` so that they match the axes of the slider. In your code, the axes of your slider are defined with `ax_time=fig.add_axes([0.25, 0.1, 0.65, 0.03])`, i.e. they start at x = 0.25 and are 0.65 wide, so they end at x = 0.90 in figure coordinates. You might therefore adjust your subplots with `plt.subplots_adjust(left=0.25, bottom=.2, right=None, top=.9, wspace=.2, hspace=.2)`; `right=None` leaves the right edge at its default of 0.9, which matches the right edge of the slider. You can play around with the axes of your slider and the axes of your subplots to get the results you want (see below for an example with the slider centered on the subplots).

In response to your second question, to relabel the values on the slider you just need to change the slider's `valmax` and `valstep` values to `valmax=num_times*win_value` and `valstep=win_value`. To index your `M_list` and `bar_list` arrays, you then need to convert the time back to an index by declaring `i` as `i=int(round(s_time.val/win_value))`. The `round()` is a safeguard: `s_time.val/win_value` is computed in floating point and can land just below the intended integer, which `int()` alone would truncate.
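A different way to handle both points, sketched under the assumption of a reasonably recent Matplotlib: read the barplot's position with `Axes.get_position()` after `subplots_adjust` and give the slider axes the same horizontal extent, then keep the slider on integer frame indices (frame k at time k*`win_value`) and only change the text it displays through its `valtext` attribute (the `Text` object the slider uses to show its value). The Agg backend and the random `bar_data` are illustrative assumptions so the sketch runs without a window; it is not the code from the question.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, only so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.widgets import Slider

num_times, win_value = 10, 0.25
bar_data = np.random.default_rng(0).random((num_times, 8))  # illustrative data

fig, ax = plt.subplots(2)
fig.subplots_adjust(bottom=0.2, top=0.9, hspace=0.2)

# Give the slider axes the barplot's horizontal extent (figure coordinates)
pos = ax[1].get_position()
ax_time = fig.add_axes([pos.x0, 0.1, pos.width, 0.03])

# The slider still moves over integer frame indices 0..num_times-1
s_time = Slider(ax_time, 'Time', 0, num_times - 1, valinit=0, valstep=1)

def update_graph(val):
    i = int(s_time.val)
    ax[1].cla()
    ax[1].bar(range(bar_data.shape[1]), bar_data[i])
    ax[1].set_ylim(0, 1)
    # Overwrite the displayed value: show the time instead of the index
    s_time.valtext.set_text(f'{i * win_value:.2f} s')
    fig.canvas.draw_idle()

s_time.on_changed(update_graph)
s_time.set_val(3)
print(s_time.valtext.get_text())  # 0.75 s
```

Because `valtext` is overwritten only inside the callback, an initial `set_val` call (as in the question's code) is what produces the first time label.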
For more details, below is the code you provided after implementing the modifications described above:
```python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.widgets import Slider
import random

def heatmap_barplot_animation(labels, M_list, bar_list, win_value):
    num_times = len(M_list)
    fig, ax = plt.subplots(2)
    # left=0.25 and right=0.9 match the slider axes, which span x = 0.25 to 0.90
    plt.subplots_adjust(left=0.25, bottom=.2, right=0.9, top=.9, wspace=.2, hspace=.2)
    ax_time = fig.add_axes([0.25, 0.1, 0.65, 0.03])
    # The slider now moves in seconds: 0, win_value, 2*win_value, ...
    s_time = Slider(ax_time, 'Time', valinit=0, valmin=0,
                    valmax=num_times * win_value, valstep=win_value)
    # Create the colorbar axes once instead of appending a new one on every update
    cax0 = make_axes_locatable(ax[0]).append_axes('right', size='7%', pad='2%')

    def update_graph(val):
        # Convert the time back to a frame index; round() guards against
        # quotients that land just below an integer in floating point
        i = int(round(s_time.val / win_value))
        ax[0].cla()
        heatmap = ax[0].imshow(M_list[i - 1], vmin=0, vmax=1, cmap='coolwarm', aspect='auto')
        ax[0].set_xticks(range(len(labels)))
        ax[0].set_xticklabels(labels, fontsize=10)
        ax[0].set_yticks(range(len(labels)))
        ax[0].set_yticklabels(labels, fontsize=10)
        cax0.cla()
        fig.colorbar(heatmap, cax=cax0, orientation='vertical')
        ax[1].cla()
        ax[1].bar(labels, bar_list[i - 1])
        ax[1].set_ylim(0, 1)
        fig.canvas.draw_idle()  # redraw without calling plt.show() on every update

    s_time.on_changed(update_graph)
    s_time.set_val(0)
    plt.show()
    # Return the slider so the caller keeps a reference to it;
    # an unreferenced widget may be garbage-collected and stop responding
    return fig, s_time

labels = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
Nc = 8
Nt = 10
M_list = [np.random.rand(Nc, Nc) for i in range(Nt)]
bar_list = [[random.uniform(0, 1) for i in range(Nc)] for t in range(Nt)]
win_value = .25
fig, s_time = heatmap_barplot_animation(labels, M_list, bar_list, win_value)
```
And the output gives (at frame number 3):