How to center matplotlib slider below figure and relabel slider label?

Time:12-24

I've created a function that plays an animation using a slider. Each frame consists of a heatmap (with colorbar) above a barplot. The function's arguments consist of a list of text labels, used both for the heatmap axis labels and for the barplot's horizontal axis labels, a list of matrices, and a list of lists to be used for the barplot. There is also a time window value, win_value, so that frame zero corresponds to time zero, frame one to win_value, frame two to 2*win_value, and so on.
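The frame/time mapping just described is simple arithmetic. Below is a minimal sketch of both directions; the helper names frame_to_time and time_to_frame are hypothetical and not part of the original code. It also shows why rounding is safer than truncating with int() when converting a time back to a frame index.

```python
# Hypothetical helper names; they only illustrate the frame <-> time mapping.

def frame_to_time(k, win_value):
    """Time in seconds that frame k corresponds to: k * win_value."""
    return k * win_value

def time_to_frame(t, win_value):
    """Frame index for time t. Uses round() rather than int(), because
    float division can land just below a whole number."""
    return round(t / win_value)

win_value = 0.25
print([frame_to_time(k, win_value) for k in range(4)])  # [0.0, 0.25, 0.5, 0.75]
print(time_to_frame(0.75, win_value))                   # 3

# Why round(): 0.3 / 0.1 evaluates to 2.9999999999999996 in floating point
print(int(0.3 / 0.1), time_to_frame(0.3, 0.1))          # 2 3
```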

The code for the function is as follows:

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.widgets import Slider

def heatmap_barplot_animation(labels,M_list,bar_list,win_value):
    num_times=len(M_list)

    fig, ax = plt.subplots(2)
    plt.subplots_adjust(left=None, bottom=.2, right=None, top=.9, wspace=.2, hspace=.2)

    ax_time=fig.add_axes([0.25, 0.1, 0.65, 0.03])
    s_time = Slider(ax_time, 'Time', 0, num_times, valinit=0,valstep=1)

    def update_graph(val):
        i= s_time.val
        ax[0].cla()
        heatmap=ax[0].imshow(M_list[i-1],vmin=0, vmax=1, cmap='coolwarm', aspect='auto')

        ax[0].set_xticks(range(len(labels)))
        ax[0].set_xticklabels(labels,fontsize=10)
        ax[0].set_yticks(range(len(labels)))
        ax[0].set_yticklabels(labels,fontsize=10)

        ax0_divider = make_axes_locatable(ax[0])
        cax0 = ax0_divider.append_axes('right', size='7%', pad='2%')
        cb = fig.colorbar(heatmap, cax=cax0, orientation='vertical')

        ax[1].cla()
        ax[1].bar(labels,bar_list[i-1])
        ax[1].set_ylim(0, 1)

        plt.show()

    s_time.on_changed(update_graph)
    s_time.set_val(0)

An example with eight labels, 10 frames, and window value .25 seconds:

import random
labels=['a','b','c','d','e','f','g','h']
Nc=len(labels)   # number of labels (8)
Nt=10            # number of frames
M_list=[np.random.rand(Nc,Nc) for i in range(Nt)]
bar_list=[[random.uniform(0,1) for i in range(Nc)] for t in range(Nt)]
win_value=.25

heatmap_barplot_animation(labels,M_list,bar_list,win_value)

The third frame of the animation looks like this:

[Figure: third frame of the animation]

I can't seem to figure out what modifications are needed to do the following:

  1. Center the slider under the barplot.
  2. Change the barplot slider so that instead of showing the index (3 above), it shows the corresponding time value, .75 seconds in this case.

Answer:

  1. For your first question, one way to center the slider under the subplots is to adjust the position of the subplots with plt.subplots_adjust so that they match the slider's axes. In your code the slider's axes are defined with ax_time=fig.add_axes([0.25, 0.1, 0.65, 0.03]), i.e. they span x = 0.25 to 0.25 + 0.65 = 0.9 in figure coordinates. Since right=None leaves the right edge unchanged at its default of 0.9, setting left=0.25 makes the subplots span the same range: plt.subplots_adjust(left=0.25, bottom=.2, right=None, top=.9, wspace=.2, hspace=.2). Because the colorbar takes some width from the heatmap, the alignment is exact for the barplot. You can play around with the slider's axes and the subplots' axes to get the result you want (see below for an example with the slider centered on the subplots).

  2. For your second question, to display the time instead of the frame index, make the slider itself run in time units: change its valmax and valstep to valmax=num_times*win_value and valstep=win_value. To index your M_list and bar_list lists you then need to convert the slider value back to an integer index with i=int(s_time.val/win_value).
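As an alternative to matching the numbers by hand, one option is to read the barplot's position with Axes.get_position() after plt.subplots_adjust and build the slider axes from its left edge and width. This sketch also uses the Slider's valfmt parameter (a %-format string) to print the value in seconds. It is a minimal sketch under stated assumptions: the bar data is random placeholder data, only the barplot is redrawn, the slider stops at the last frame's time ((frames - 1) * win_value) so each position maps to frame round(t / win_value), and the Agg backend line is there only so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display; remove for a window
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.widgets import Slider

# Placeholder data (hypothetical): 10 frames of bar heights for 8 labels
rng = np.random.default_rng(0)
labels = list("abcdefgh")
win_value = 0.25
bar_list = rng.random((10, len(labels)))

fig, ax = plt.subplots(2)
plt.subplots_adjust(left=0.25, bottom=0.2, top=0.9, hspace=0.2)

# Reuse the barplot's left edge and width (figure coordinates) for the slider,
# so both span exactly the same horizontal range.
pos = ax[1].get_position()
ax_time = fig.add_axes([pos.x0, 0.08, pos.width, 0.03])

# Slider in time units; the last stop is the last frame, (frames - 1) * win_value.
# valfmt formats the value text drawn to the right of the slider.
s_time = Slider(ax_time, "Time", valmin=0, valmax=(len(bar_list) - 1) * win_value,
                valinit=0, valstep=win_value, valfmt="%.2f s")

def update(val):
    k = round(val / win_value)  # time -> frame index
    ax[1].cla()
    ax[1].bar(labels, bar_list[k])
    ax[1].set_ylim(0, 1)
    fig.canvas.draw_idle()

s_time.on_changed(update)
s_time.set_val(0.75)  # shows frame 3
print(s_time.valtext.get_text())  # 0.75 s
```

Deriving the rectangle from get_position() means the slider follows the barplot if you later change left or right in subplots_adjust.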

Here is the code you provided with the modifications described above:

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.widgets import Slider
import random

def heatmap_barplot_animation(labels,M_list,bar_list,win_value):
    num_times=len(M_list)
    fig, ax = plt.subplots(2)
    
    plt.subplots_adjust(left=0.25, bottom=.2, right=None, top=.9, wspace=.2, hspace=.2)
    ax_time=fig.add_axes([0.25, 0.1, 0.65, 0.03])
    s_time = Slider(ax_time, 'Time',valinit=0,valmin=0,valmax=num_times*win_value,valstep=win_value)

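    # Create the colorbar axes once; calling append_axes inside update_graph
    # would add a new colorbar axes on every slider move.
    ax0_divider = make_axes_locatable(ax[0])
    cax0 = ax0_divider.append_axes('right', size='7%', pad='2%')

    def update_graph(val):
        # Convert the slider's time value back to a frame index
        i=int(s_time.val/win_value)
        ax[0].cla()
        heatmap=ax[0].imshow(M_list[i-1],vmin=0, vmax=1, cmap='coolwarm', aspect='auto')

        ax[0].set_xticks(range(len(labels)))
        ax[0].set_xticklabels(labels,fontsize=10)
        ax[0].set_yticks(range(len(labels)))
        ax[0].set_yticklabels(labels,fontsize=10)

        # Redraw the colorbar in the same axes
        cax0.cla()
        fig.colorbar(heatmap, cax=cax0, orientation='vertical')

        ax[1].cla()
        ax[1].bar(labels,bar_list[i-1])
        ax[1].set_ylim(0, 1)

        # Request a redraw instead of calling plt.show() on every change
        fig.canvas.draw_idle()

    s_time.on_changed(update_graph)
    s_time.set_val(0)
    # Show the window once; in a regular script plt.show() blocks here,
    # so s_time stays referenced and the slider stays responsive
    plt.show()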

labels=['a','b','c','d','e','f','g','h']
Nc=8
Nt=10
M_list=[np.random.rand(Nc,Nc) for i in range(Nt)]
bar_list=[[random.uniform(0,1) for i in range(Nc)] for t in range(Nt)]
win_value=.25
heatmap_barplot_animation(labels,M_list,bar_list,win_value)

And the output gives (at frame number 3):

[Figure: output at frame number 3]

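If you would rather keep the slider in frame-index units (so M_list and bar_list can be indexed directly), another option is to leave the slider's value alone and only change the text it displays, through the Text object the slider keeps in its valtext attribute. A minimal sketch, assuming a hypothetical frame count and a single empty axes standing in for the heatmap and barplot; the Agg backend line is only there so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; remove for an interactive window
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider

win_value = 0.25
num_frames = 10  # hypothetical number of frames

fig, ax = plt.subplots()
plt.subplots_adjust(bottom=0.25)
# 0.125 and 0.775 match the default left edge (0.125) and width (0.9 - 0.125)
ax_time = fig.add_axes([0.125, 0.1, 0.775, 0.03])

# The slider keeps counting frames 0, 1, 2, ..., so it can index lists directly
s_time = Slider(ax_time, "Time", valmin=0, valmax=num_frames - 1, valinit=0, valstep=1)

def show_time(val):
    # Replace the frame number printed next to the slider with the time in seconds
    s_time.valtext.set_text(f"{round(val) * win_value:.2f} s")

s_time.on_changed(show_time)
show_time(s_time.val)  # label the initial position too
s_time.set_val(3)
print(s_time.valtext.get_text())  # 0.75 s
```

Slider.set_val writes its own formatted value to valtext before notifying observers, so the callback's text is the one that remains visible.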