Home > OS >  Matplotlib - Split graph creation into multiple functions
Matplotlib - Split graph creation into multiple functions

Time:07-26

In order to create figures that share some of the same graphs, I would like to define one function per group of graphs. Each function should be called with the subfigure it draws into, so that its graphs end up in the right location. Consequently, I would like to split the code below into separate functions, as in the second snippet further down.

fig = plt.figure(figsize=(10, 8), constrained_layout=True)

# Top band gets 1/4 of the figure height, bottom band the remaining 3/4.
subfig_t, subfig_b = fig.subfigures(
    nrows=2, ncols=1, hspace=0.05, height_ratios=[1, 3]
)

# Top band: one axes with its own x label.
ax0 = subfig_t.subplots()
ax0.set_title('ax0')
subfig_t.supxlabel('xlabel0')

# Bottom band: wide left part (3/4 of the width) and narrow right part (1/4).
subfig_bl, subfig_br = subfig_b.subfigures(
    nrows=1, ncols=2, wspace=0.1, width_ratios=[3, 1]
)

# Bottom-left: a 1 x 9 grid holding three panels that are 1, 5 and 3 columns
# wide. ax2 and ax3 share ax1's y scale, so only ax1 keeps a visible y-axis.
gs = subfig_bl.add_gridspec(nrows=1, ncols=9)
ax1 = subfig_bl.add_subplot(gs[0, :1])
ax2 = subfig_bl.add_subplot(gs[0, 1:6], sharey=ax1)
ax3 = subfig_bl.add_subplot(gs[0, 6:], sharey=ax1)
ax1.set_title('ax1')
ax2.set_title('ax2')
ax3.set_title('ax3')
ax2.yaxis.set_visible(False)
ax3.yaxis.set_visible(False)
subfig_bl.supxlabel('xlabel1-3')

# Bottom-right: one axes with its own x label.
ax4 = subfig_br.subplots()
ax4.set_title('ax4')
subfig_br.supxlabel('xlabel4')

Below is the kind of code I would like to have, to avoid writing the same code multiple times.

fig = plt.figure(figsize=(10, 8), constrained_layout=True)

# Top/bottom split: the top band gets 1/4 of the height, the bottom 3/4.
subfig_t, subfig_b = fig.subfigures(
    nrows=2, ncols=1, hspace=0.05, height_ratios=[1, 3]
)
# Left/right split of the bottom band: 3/4 of the width vs. 1/4.
subfig_bl, subfig_br = subfig_b.subfigures(
    nrows=1, ncols=2, wspace=0.1, width_ratios=[3, 1]
)

def func1(subfig_t):
    """Put a single axes titled 'ax0' into *subfig_t* and set its x label.

    Returns:
        The same subfigure that was passed in.
    """
    # put ax0 in top subfig
    ax0 = subfig_t.subplots()
    ax0.set_title('ax0')
    subfig_t.supxlabel('xlabel0')
    return subfig_t

def func2(subfig_bl):
    """Put three side-by-side axes (ax1-ax3) into *subfig_bl*.

    The subfigure is divided into a 1 x 9 grid: ax1 takes column 0, ax2
    columns 1-5 and ax3 columns 6-8. ax2 and ax3 share ax1's y-axis and hide
    their own, so the y-axis is drawn only once, on ax1.

    Returns:
        The same subfigure that was passed in.
    """
    # put ax1-ax3 in gridspec of bottom-left subfig
    gs = subfig_bl.add_gridspec(nrows=1, ncols=9)
    ax1 = subfig_bl.add_subplot(gs[0, :1])
    ax2 = subfig_bl.add_subplot(gs[0, 1:6], sharey=ax1)
    ax3 = subfig_bl.add_subplot(gs[0, 6:], sharey=ax1)
    ax1.set_title('ax1')
    ax2.set_title('ax2')
    ax3.set_title('ax3')
    ax2.get_yaxis().set_visible(False)
    ax3.get_yaxis().set_visible(False)
    subfig_bl.supxlabel('xlabel1-3')
    return subfig_bl

def func3(subfig_br):
    """Put a single axes titled 'ax4' into *subfig_br* and set its x label.

    Returns:
        The same subfigure that was passed in.
    """
    # put ax4 in bottom-right subfig
    ax4 = subfig_br.subplots()
    ax4.set_title('ax4')
    subfig_br.supxlabel('xlabel4')
    # Return the subfigure we drew into; returning the bottom-left one would
    # make `subfig_br = func3(subfig_br)` point at the wrong subfigure.
    return subfig_br

def func_save(fig, OutputPath):
    """Save *fig* to *OutputPath* as a PNG at 300 dpi with a tight bounding box."""
    fig.savefig(OutputPath, dpi=300, format='png', bbox_inches='tight')

# Path of the saved image; change as needed.
OutputPath = "output.png"

# Each helper draws into the subfigure it is given, so the return values are
# not needed. Rebinding the names from them would silently point a name at the
# wrong subfigure if a helper ever returned a different one.
func1(subfig_t)
func2(subfig_bl)
func3(subfig_br)
func_save(fig, OutputPath)

CodePudding user response:

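The function bodies are not indented, so Python does not treat them as part of the function definitions. With a few syntax changes (indenting each body and defining OutputPath), the code runs. Python's syntax differs from many other programming languages: indentation defines code blocks. It is simple to learn, but unfamiliar rules can be confusing at first.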

The code below runs as expected; I hope you find it useful.

import numpy as np
import matplotlib.pyplot as plt


fig = plt.figure(figsize=(10, 8), constrained_layout=True)

# Top/bottom split: the top band gets 1/4 of the height, the bottom 3/4.
subfig_t, subfig_b = fig.subfigures(
    nrows=2, ncols=1, hspace=0.05, height_ratios=[1, 3]
)
# Left/right split of the bottom band: 3/4 of the width vs. 1/4.
subfig_bl, subfig_br = subfig_b.subfigures(
    nrows=1, ncols=2, wspace=0.1, width_ratios=[3, 1]
)


def func1(subfig_t):
    """Fill *subfig_t* with one axes titled 'ax0' and give it an x label.

    Returns:
        The same subfigure that was passed in.
    """
    subfig_t.subplots().set_title("ax0")
    subfig_t.supxlabel("xlabel0")
    return subfig_t


def func2(subfig_bl):
    """Fill *subfig_bl* with three side-by-side axes titled ax1, ax2 and ax3.

    The subfigure is divided into a 1 x 9 grid: the first panel spans column
    0, the second columns 1-5 and the third columns 6-8. The second and third
    panels share the first panel's y-axis and hide their own.

    Returns:
        The same subfigure that was passed in.
    """
    grid = subfig_bl.add_gridspec(nrows=1, ncols=9)
    left = subfig_bl.add_subplot(grid[0, :1])
    # The remaining panels reuse the first panel's y scale.
    middle, right = (
        subfig_bl.add_subplot(grid[0, columns], sharey=left)
        for columns in (slice(1, 6), slice(6, None))
    )
    for panel, title in zip((left, middle, right), ("ax1", "ax2", "ax3")):
        panel.set_title(title)
    # A shared y-axis only needs to be drawn once, on the leftmost panel.
    for panel in (middle, right):
        panel.get_yaxis().set_visible(False)
    subfig_bl.supxlabel("xlabel1-3")
    return subfig_bl

def func3(subfig_br):
    """Put a single axes titled 'ax4' into *subfig_br* and set its x label.

    Returns:
        The same subfigure that was passed in.
    """
    # put ax4 in bottom-right subfig
    ax4 = subfig_br.subplots()
    ax4.set_title('ax4')
    subfig_br.supxlabel('xlabel4')
    # Return the subfigure we drew into; returning the global bottom-left one
    # would make `subfig_br = func3(subfig_br)` point at the wrong subfigure.
    return subfig_br

def func_save(fig, OutputPath):
    """Write *fig* to *OutputPath* as a 300-dpi PNG with a tight bounding box."""
    save_options = {"dpi": 300, "format": "png", "bbox_inches": "tight"}
    fig.savefig(OutputPath, **save_options)


# Path of the saved image; change as needed.
OutputPath = "output.png"

# Each helper draws into the subfigure it is given, so the return values are
# not needed. Rebinding the names from them would silently point a name at the
# wrong subfigure if a helper ever returned a different one.
func1(subfig_t)
func2(subfig_bl)
func3(subfig_br)
func_save(fig, OutputPath)

Happy coding :)

  • Related