I want to draw multiple scatter-plot subplots inside a function, using *args to collect the (x, y) input arrays. However, I keep getting this error:
ValueError: s must be a scalar, or float array-like with the same size as x and y
I cannot solve it, even after trying different orderings of the args. Here is my sample code:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
x = np.array([[1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4],
[0.3, 0.5, 0.6, 0.2, 0.4, 0.5, 0.6, 0.5, 0.8, 0.9, 0.9, 0.8, 0.2, 0.1, 0.5, 0.6],
['r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b']])
values = pd.DataFrame(x.T, columns=['a', 'b', 'c'])
X = values[values['c'] == 'r'].iloc[ : , 0:2 ].values
Y = values[values['c'] == 'b'].iloc[ : , 0:2 ].values
def test(*args):
    figs, axs = plt.subplots(1, 2, figsize=(8, 8))
    for xy, ax in zip(args, axs.flat):
        print(xy)
        ax.scatter(*xy)
test(X, Y)
plt.show()
CodePudding user response:
I got it working with the following, but perhaps there is a cleaner alternative?
def test(*args):
    figs, axs = plt.subplots(1, 2, figsize=(8, 8))
    xy = np.array(args)
    for x_y, ax in zip(xy, axs.flat):
        # Flatten column by column so the first half is x and the second half is y.
        (x, y) = np.hsplit(x_y.flatten(order='F'), 2)
        ax.scatter(x, y)
test(X, Y)
plt.show()
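One possibly cleaner alternative, offered as a sketch rather than a definitive fix: each argument is an (n, 2) array, so `*xy` unpacks its n rows and the third row lands in `scatter`'s `s` parameter. Transposing first makes the unpack yield exactly two sequences, the x column and the y column. The name `scatter_panels`, the stand-in data, and the float conversion are my own additions; the conversion assumes the values are numeric strings, as they appear to be in the question's data.

```python
import numpy as np
import matplotlib.pyplot as plt


def scatter_panels(*args):
    """Draw one scatter subplot per (n, 2) array passed in."""
    fig, axs = plt.subplots(1, len(args), figsize=(8, 8), squeeze=False)
    for xy, ax in zip(args, axs.flat):
        # Assumption: the values may be numeric strings, so convert them.
        xy = np.asarray(xy, dtype=float)
        # xy.T has shape (2, n), so unpacking gives the x and y columns.
        x, y = xy.T
        ax.scatter(x, y)
    return fig, axs


# Small stand-in arrays shaped like X and Y in the question.
X = np.array([["1", "0.3"], ["1", "0.5"], ["2", "0.4"]])
Y = np.array([["3", "0.8"], ["4", "0.2"]])
fig, axs = scatter_panels(X, Y)
```

Call plt.show() afterwards to display the figure. Returning fig and axs keeps the function easy to inspect or extend.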
CodePudding user response:
Okay, this is the solution to your problem, though it is probably as incomprehensible as your code.
def t(*args):
    figs, axs = plt.subplots(1, 2, figsize=(8, 8))
    for xy, ax in zip(zip(*args), axs.flat):
        print(xy)
        ax.scatter(*xy)
t(X.transpose(), Y.transpose())
Now let's turn this into readable Python. You should be able to understand a function just by looking at it, so *args is best reserved for something "extra" beyond the function's intended behavior.
- Separate X and Y from args, because they are used explicitly inside the function.
- Avoid zip unless the arguments are simple. Nesting generators is performant, but it hurts code quality. In your case I'd use enumerate instead, since I am indexing into an array, and that indexing needs to be as descriptive as possible.
- Make your calls descriptive. *xy is not descriptive: what is this variable? Is it a tuple? An ndarray? What will unpacking it result in? If you are passing X and Y in it, then pass X and Y directly.
def t(*args):
    X, Y = args
    figs, axs = plt.subplots(1, 2, figsize=(8, 8))
    for i, ax in enumerate(axs):
        print(X[:, i], Y[:, i])
        ax.scatter(X[:, i], Y[:, i])
t(X, Y)
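The first bullet above suggests taking X and Y as named parameters, while the code still accepts *args. Here is a minimal sketch of that suggestion. The function name `scatter_columns`, the sample data, and the float conversion are hypothetical additions. Like the code above, it plots column i of X against column i of Y, so both arrays are assumed to have the same number of rows.

```python
import numpy as np
import matplotlib.pyplot as plt


def scatter_columns(X, Y):
    """Plot column i of X against column i of Y, one subplot per column."""
    # Assumption: inputs may hold numeric strings, so convert to float.
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    fig, axs = plt.subplots(1, 2, figsize=(8, 8))
    for i, ax in enumerate(axs):
        ax.scatter(X[:, i], Y[:, i])
    return fig, axs


# Hypothetical sample data: two (n, 2) arrays with the same number of rows.
X = np.array([[1, 0.3], [1, 0.5], [2, 0.4]])
Y = np.array([[3, 0.8], [3, 0.9], [4, 0.2]])
fig, axs = scatter_columns(X, Y)
```

With named parameters the signature documents what the function expects, and a wrong number of arguments fails at the call rather than at the unpacking line.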