Remove for loops when plotting matplotlib subplots

Time:06-06

I have a large subplot-based figure to produce in Python using matplotlib. In total, the figure has more than 500 individual plots, each with thousands of data points. It can be plotted with a for-loop-based approach modelled on the minimal example below:

import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

# define main plot names and subplot names
mains = ['A','B','C','D']
subs = list(range(9))

# generate mimic data in pd dataframe
col = [letter + str(number) for letter in mains for number in subs]
col.insert(0,'Time')

df = pd.DataFrame(columns=col)

for title in df.columns:
    df[title] = [i for i in range(100)]

# although alphabet and mains are the same in this minimal example this may not always be true
alphabet = ['A', 'B', 'C', 'D']
column_names = [column for column in df.columns if column != 'Time']

# define figure size and main gridshape
fig = plt.figure(figsize=(15, 15))
outer = gridspec.GridSpec(2, 2, wspace=0.2, hspace=0.2)

for i, letter in enumerate(alphabet):
    # define inner grid size and shape
    inner = gridspec.GridSpecFromSubplotSpec(3, 3,
                    subplot_spec=outer[i], wspace=0.1, hspace=0.1)

    # select only columns with correct letter
    plot_array = [col for col in column_names if col.startswith(letter)]

    # set title for each letter plot on an invisible parent axes
    ax = fig.add_subplot(outer[i])
    ax.set_title(f'Letter {letter}')
    ax.axis('off')

    # create each subplot
    for j, col in enumerate(plot_array):
        ax = fig.add_subplot(inner[j])

        X = df['Time']
        Y = df[col]

        # plot waveform
        ax.plot(X, Y)

        # hide axis lines, ticks and labels
        ax.axis('off')

        # set y-axis limits so all plots share the same y-axis
        ax.set_ylim(df[column_names].min().min(),df[column_names].max().max())

However, this is slow: it takes minutes to plot the figure. Is there a more efficient (potentially for-loop-free) method to achieve the same result?

Answer:

The bottleneck in the loop is not the plotting but setting the axis limits with df[column_names].min().min() and df[column_names].max().max(), which rescan every column of df once for every subplot.
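To see why the limit calculation dominates, here is a rough timing sketch on synthetic data. The frame size, the random values and the use of `time.perf_counter` are assumptions for illustration; absolute timings will vary by machine. The point is that the per-loop version makes two full passes over the data for every subplot, while the hoisted version makes two passes in total and yields identical limits.

```python
import time

import numpy as np
import pandas as pd

# Hypothetical data: 36 signal columns (A0..D8) plus 'Time'.
# The row count and random values are assumptions for illustration only.
n_rows = 10_000
column_names = [f"{letter}{number}" for letter in "ABCD" for number in range(9)]
rng = np.random.default_rng(0)
df = pd.DataFrame(rng.normal(size=(n_rows, len(column_names))), columns=column_names)
df.insert(0, "Time", np.arange(n_rows))

# Pattern from the question: rescan all columns once per subplot.
start = time.perf_counter()
per_loop_limits = []
for _ in column_names:
    per_loop_limits.append(
        (df[column_names].min().min(), df[column_names].max().max())
    )
repeated = time.perf_counter() - start

# Pattern from the answer: scan once, then reuse the two scalars.
start = time.perf_counter()
y_lower = df[column_names].min().min()
y_upper = df[column_names].max().max()
hoisted_limits = [(y_lower, y_upper) for _ in column_names]
hoisted = time.perf_counter() - start

print(f"repeated scans: {repeated:.4f} s, single scan: {hoisted:.4f} s")

# Both approaches produce exactly the same limits for every subplot.
assert per_loop_limits == hoisted_limits
```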

Testing with 6 main plots, 64 subplots and 375,000 data points, the plotting section of the example took approximately 360 s to complete when the axis limits were found by searching df for the min and max values on every iteration. However, by moving the search for the min and max outside the loops, e.g.

# set y_lims
y_upper = df[column_names].max().max()
y_lower = df[column_names].min().min()

and changing

ax.set_ylim(df[column_names].min().min(),df[column_names].max().max())

to

ax.set_ylim(y_lower,y_upper)

the plotting time drops to approximately 24 seconds.
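Putting the answer's change into the question's example gives roughly the sketch below. Besides computing y_lower and y_upper once before the loops, it creates each axes with `fig.add_subplot(spec)` and reads `df['Time']` once outside the loops; those are small tidy-ups that were not part of the timing above. The `matplotlib.use("Agg")` line is an assumption so the sketch runs without a display.

```python
import matplotlib

matplotlib.use("Agg")  # assumption: headless backend, no display required

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import pandas as pd

# Same mimic data as the question: 'Time' plus columns A0..D8, each 0..99.
mains = ["A", "B", "C", "D"]
subs = list(range(9))
col = ["Time"] + [letter + str(number) for letter in mains for number in subs]
df = pd.DataFrame({name: list(range(100)) for name in col})

alphabet = ["A", "B", "C", "D"]
column_names = [column for column in df.columns if column != "Time"]

# Compute the shared y-limits once, before any plotting.
y_lower = df[column_names].min().min()
y_upper = df[column_names].max().max()

fig = plt.figure(figsize=(15, 15))
outer = gridspec.GridSpec(2, 2, wspace=0.2, hspace=0.2)
X = df["Time"]  # the x data is identical for every subplot

for i, letter in enumerate(alphabet):
    inner = gridspec.GridSpecFromSubplotSpec(
        3, 3, subplot_spec=outer[i], wspace=0.1, hspace=0.1
    )

    # Invisible parent axes that only carries the title for this letter.
    title_ax = fig.add_subplot(outer[i])
    title_ax.set_title(f"Letter {letter}")
    title_ax.axis("off")

    plot_array = [c for c in column_names if c.startswith(letter)]
    for j, name in enumerate(plot_array):
        ax = fig.add_subplot(inner[j])
        ax.plot(X, df[name])
        ax.axis("off")
        ax.set_ylim(y_lower, y_upper)  # reuse the precomputed limits

# Save or show the figure as usual, e.g. fig.savefig(...) or plt.show().
```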
