Home > Net >  How to get wrapper creating and saving figure, around plot function?
How do I write a wrapper that creates and saves a figure around a plot function?

Time:12-10

I read about wrappers (decorators) and would like to use them for plotting. I want a wrapper function that creates and saves the figure for each of my plot functions. However, I get the error shown below. Here is my code:

import numpy as np
import matplotlib.pyplot as plt


# figure creating and saving wrapper
def figure_wrapper(func):
    """Decorate a plotting *method* so it draws into a fresh figure that is saved and closed.

    The decorated method must have the signature ``method(self, fig, *args, **kwargs)``
    and return ``(fig, name)``. The wrapper creates a new figure, passes it to the
    method right after ``self``, saves the returned figure to ``name`` with
    ``bbox_inches='tight'`` and then closes it. The wrapper itself returns ``None``.
    """
    import functools

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        fig = plt.figure()
        try:
            # The figure must go *after* ``self``. Passing it as the first
            # positional argument binds the figure to ``self`` and the instance
            # to ``fig``. That produces
            # "'plotting' object has no attribute 'add_subplot'".
            fig, name = func(self, fig, *args, **kwargs)
            fig.savefig(name, bbox_inches='tight')
        finally:
            # Close even if drawing or saving fails, so figures are not leaked.
            plt.close(fig)

    return wrapper

# plotting class
class plotting:
    """Plot methods meant to be decorated with ``figure_wrapper``.

    Each method takes a figure ``fig`` plus the data ``x`` and ``y``. It draws on a
    new subplot of ``fig`` and returns ``(fig, name)``; ``figure_wrapper`` then
    saves the figure under ``name``.
    """

    @figure_wrapper
    def plot1(self, fig, x, y):
        """Draw ``y`` against ``x`` as a line plot; the output name is ``'plot1'``."""
        fig.add_subplot().plot(x, y)
        return fig, 'plot1'

    @figure_wrapper
    def scatter1(self, fig, x, y):
        """Draw ``y`` against ``x`` as a scatter plot; the output name is ``'scatter1'``."""
        fig.add_subplot().scatter(x, y)
        return fig, 'scatter1'

# data for plotting
x = np.linspace(start=0, stop=10, num=10)
y = np.linspace(start=20, stop=30, num=10)
x1 = np.linspace(start=20, stop=10, num=10)  # descending: 20 -> 10
y1 = np.linspace(start=60, stop=30, num=10)  # descending: 60 -> 30

# Instantiate and draw both plots. Note that this rebinds ``plotting`` from the
# class to its instance, so the class is no longer reachable by that name.
plotting = plotting()
plotting.plot1(x=x, y=y)
plotting.scatter1(x=x1, y=y1)

The error that I get:

Traceback (most recent call last):
  File "C:\Users\jerzy\Documents\Test_drawing_saving.py", line 52, in <module>
    plotting.plot1(x = x,y= y)
  File "C:\Users\jerzy\Documents\Test_drawing_saving.py", line 10, in wrapper
    fig,name = func(fig,*args,**kwargs)
  File "C:\Users\jerzy\Documents\Test_drawing_saving.py", line 23, in plot
    ax = fig.add_subplot()
AttributeError: 'plotting' object has no attribute 'add_subplot'

Answer:

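The error occurs because the wrapper passes `fig` as the first positional argument. Inside the method, the new figure is therefore bound to `self`, and the `plotting` instance ends up in the `fig` parameter. I would instead declare `fig` as an attribute of your `plotting` class and pass only `x` and `y` through the wrapper. See the code below: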

import numpy as np
import matplotlib.pyplot as plt


# figure creating and saving wrapper
def figure_wrapper(func):
    """Decorate a plotting method so the figure it returns is saved, cleared and closed.

    The decorated callable must return ``(fig, name)``. The figure is saved to
    ``name`` with ``bbox_inches='tight'``. It is then cleared and closed so it can
    be reused empty by the next call. The wrapper itself returns ``None``.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        fig, name = func(*args, **kwargs)
        try:
            fig.savefig(name, bbox_inches='tight')
        finally:
            # Clear *this* figure rather than pyplot's current one. Once ``fig``
            # has been closed it is no longer current, so ``plt.clf()`` would
            # create and clear a new, unrelated figure. ``fig`` would then keep
            # its axes into the next call.
            fig.clf()
            plt.close(fig)

    return wrapper

# plotting class
class plotting:
    """Plot methods that draw into one shared, class-level figure.

    Each method adds a labelled subplot to ``fig`` and returns ``(fig, name)``
    for ``figure_wrapper`` to save under ``name``.
    """

    # NOTE(review): this Figure is created once, when the class body runs, and is
    # shared by every instance. Axes added by successive calls accumulate in it
    # unless the decorator clears it between calls.
    fig = plt.figure()

    @figure_wrapper
    def plot(self, x, y):
        """Add a line plot of ``y`` against ``x`` to the shared figure, named ``'plot1'``."""
        label = 'plot1'
        # A distinct ``label`` presumably stops ``add_subplot`` from returning an
        # existing Axes created with the same arguments (older Matplotlib reuse
        # behaviour). Confirm against the Matplotlib version in use.
        axes = self.fig.add_subplot(label=label)
        axes.plot(x, y)
        return self.fig, label

    @figure_wrapper
    def scatter(self, x, y):
        """Add a scatter plot of ``y`` against ``x`` to the shared figure, named ``'scatter1'``."""
        label = 'scatter1'
        axes = self.fig.add_subplot(label=label)
        axes.scatter(x, y)
        return self.fig, label

# data for plotting
x = np.linspace(0, 10, 10)
y = np.linspace(20, 30, 10)
x1 = np.linspace(20, 10, 10)
y1 = np.linspace(60, 30, 10)

# Instantiate and draw both plots (this rebinds ``plotting`` from the class to
# the instance). The scatter uses the second data set ``x1``/``y1``, which would
# otherwise be unused. This matches the question's ``scatter1(x=x1, y=y1)`` call.
plotting = plotting()
plotting.plot(x, y)
plotting.scatter(x1, y1)

Running the code saves the following two figures:

enter image description here

enter image description here

  • Related