I read about wrappers and would like to use it for plotting. I wanted to have a wrapper function that creates and saves figure in my plot functions. However I get the error shown below. Here is my code:
import numpy as np
import matplotlib.pyplot as plt
# figure creating and saving wrapper
def figure_wrapper(func):
def wrapper(*args,**kwargs):
fig = plt.figure()
fig,name = func(fig,*args,**kwargs)
fig.savefig(name, bbox_inches='tight')
plt.close(fig)
return wrapper
# plotting class
class plotting():
@figure_wrapper
def plot1(self,fig,x,y):
name = 'plot1'
ax = fig.add_subplot()
ax.plot(x,y)
return fig, name
@figure_wrapper
def scatter1(self,fig,x,y,):
name = 'scatter1'
ax = fig.add_subplot()
ax.scatter(x, y)
return fig, name
# data for plotting
x = np.linspace(0, 10, 10)
y = np.linspace(20, 30, 10)
x1 = np.linspace(20, 10, 10)
y1 = np.linspace(60, 30, 10)
# execution of class
plotting = plotting()
plotting.plot1(x = x,y= y)
plotting.scatter1(x = x1,y= y1)
The error that I get:
Traceback (most recent call last):
File "C:\Users\jerzy\Documents\Test_drawing_saving.py", line 52, in <module>
plotting.plot1(x = x,y= y)
File "C:\Users\jerzy\Documents\Test_drawing_saving.py", line 10, in wrapper
fig,name = func(fig,*args,**kwargs)
File "C:\Users\jerzy\Documents\Test_drawing_saving.py", line 23, in plot
ax = fig.add_subplot()
AttributeError: 'plotting' object has no attribute 'add_subplot'
CodePudding user response:
I would just declare your fig
as an attribute to your plotting
class and just pass x
and y
to your wrapper.
See code below:
import numpy as np
import matplotlib.pyplot as plt
# figure creating and saving wrapper
def figure_wrapper(func):
def wrapper(*args,**kwargs):
fig,name = func(*args,**kwargs)
fig.savefig(name, bbox_inches='tight')
plt.clf()
plt.close(fig)
return wrapper
# plotting class
class plotting():
fig=plt.figure()
@figure_wrapper
def plot(self,x,y):
name = 'plot1'
ax = self.fig.add_subplot(label=name)
ax.plot(x,y)
return self.fig, name
@figure_wrapper
def scatter(self,x,y,):
name = 'scatter1'
ax = self.fig.add_subplot(label=name)
ax.scatter(x, y)
return self.fig, name
# data for plotting
x = np.linspace(0, 10, 10)
y = np.linspace(20, 30, 10)
x1 = np.linspace(20, 10, 10)
y1 = np.linspace(60, 30, 10)
# execution of class
plotting = plotting()
plotting.plot(x,y)
plotting.scatter(x,y)
The output saves the two following figures: