How to change functions according to values in a numpy array

Time:01-31

I am trying to make a surface plot of a function that looks like this:

def model(param, x_1, x_2, x_3, x_4):
    est = param[0] + param[1]*(x_1 + x_2*x_3 + x_2**2*x_4)
    return est

The point is that x_3 and x_4 depend on x_2: x_3 = 1 when x_2 >= 0 and x_4 = 1 when x_2 < 0 (each is 0 otherwise).

When I tried to make a surface plot, I was not sure how to build the mesh grid, because the function takes two more variables in addition to x_1 and x_2.
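Because x_3 and x_4 are fully determined by x_2, only x_1 and x_2 need to be on the mesh grid; the indicators can be computed element-wise from the x_2 grid with boolean comparisons. The sketch below illustrates this. The helper name `model_on_grid`, the parameter values and the grid ranges are hypothetical choices for illustration, not part of the original question.

```python
import numpy as np

def model(param, x_1, x_2, x_3, x_4):
    # Original model: x_3 and x_4 act as 0/1 indicator variables
    return param[0] + param[1] * (x_1 + x_2 * x_3 + x_2**2 * x_4)

def model_on_grid(param, x_1, x_2):
    # Hypothetical helper: derive the indicators from x_2 element by element,
    # so no Python `if` on an array is needed
    x_3 = (x_2 >= 0).astype(float)  # 1 where x_2 >= 0, else 0
    x_4 = (x_2 < 0).astype(float)   # 1 where x_2 < 0, else 0
    return model(param, x_1, x_2, x_3, x_4)

param = [10, 8]                      # hypothetical parameter values
x1_values = np.linspace(-2, 2, 5)
x2_values = np.linspace(-3, 3, 7)
X1, X2 = np.meshgrid(x1_values, x2_values)  # both grids have shape (7, 5)
Z = model_on_grid(param, X1, X2)            # Z has the same shape as the grids
print(Z.shape)                              # (7, 5)
```

The resulting X1, X2 and Z arrays can then be handed to matplotlib, for example `ax.plot_surface(X1, X2, Z)` on an axes created with `projection="3d"`.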

To compute the z values, I tried modifying the function like this:

def function(param, x_1, x_2):
    if x_2 > 0:
        est = param[0] + param[1]*(x_1 + x_2)
    else:
        est = param[0] + param[1]*(x_1 + x_2**2)
    return est

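However, this raises "ValueError: The truth value of an array with more than one element is ambiguous". As I understand it, Python tries to decide whether the condition holds for the whole array x_2 at once, rather than for each element.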
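A small sketch reproducing the error may help: comparing an array with a scalar gives a boolean array, and `if` needs a single bool, so NumPy refuses to guess. The array values here are arbitrary examples.

```python
import numpy as np

x_2 = np.array([-1.0, 0.0, 2.0])
mask = x_2 > 0          # element-wise comparison gives a boolean array
print(mask)             # [False False  True]

try:
    if mask:            # `if` needs ONE bool, but mask holds three
        pass
    ambiguous = False
except ValueError as err:
    ambiguous = True
    print("ValueError:", err)

# Reduce the array to a single bool explicitly, depending on what you mean:
print(mask.any())       # True: at least one element is > 0
print(mask.all())       # False: not every element is > 0
```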

I also tried to use np.sign(), but it does not behave the way I want in this case.
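One likely reason np.sign() did not help is that it is three-valued (-1, 0 or +1) rather than a 0/1 switch. As an alternative not mentioned in the thread, `np.heaviside(x, h0)` returns 0 for x < 0, h0 for x == 0 and 1 for x > 0, so with `h0=1.0` it gives exactly the x_3 indicator. The sample array is an arbitrary example.

```python
import numpy as np

x_2 = np.array([-2.0, 0.0, 3.0])

# np.sign gives -1, 0 or +1, so it is not directly a 0/1 indicator
print(np.sign(x_2))            # [-1.  0.  1.]

# With h0=1.0, heaviside is 1 where x_2 >= 0 and 0 where x_2 < 0
x_3 = np.heaviside(x_2, 1.0)   # gives [0. 1. 1.]
x_4 = 1.0 - x_3                # gives [1. 0. 0.], the indicator for x_2 < 0
```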

Is there a way to change the function according to the value of each element in the array, and/or to solve this without computing the z axis manually with a for loop?

Answer:

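If you want to check whether all values are greater than 0, use np.all, which reduces the whole array to a single True/False:

import numpy as np

def function(param, x_1, x_2):
    if np.all(x_2 > 0):
        est = param[0] + param[1]*(x_1 + x_2)
    else:
        est = param[0] + param[1]*(x_1 + x_2**2)
    return est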
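A note on choosing `np.all` over Python's built-in `all()`: on a 2-D mesh grid, the built-in iterates over rows, and each row is itself an array, so the ambiguous-truth-value error comes back. The sketch below, using arbitrary grid ranges, illustrates the difference.

```python
import numpy as np

X1, X2 = np.meshgrid(np.linspace(0, 1, 3), np.linspace(1, 2, 4))
mask = X2 > 0          # 2-D boolean array with shape (4, 3)

# np.all reduces over every element and returns a single bool
print(np.all(mask))    # True

# The built-in all() calls bool() on each ROW, which is a length-3 array
try:
    all(mask)
    builtin_all_failed = False
except ValueError:
    builtin_all_failed = True
print(builtin_all_failed)  # True
```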

But if you want to apply the test to each element, use np.where:

def function(param, x_1, x_2):
    return np.where(x_2 > 0,
                    param[0] + param[1]*(x_1 + x_2),
                    param[0] + param[1]*(x_1 + x_2**2))
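To check that the element-wise np.where version reproduces the original indicator-based model on a full mesh grid, here is a sketch; the parameter values and grid ranges are hypothetical. Note that at x_2 == 0 both branches reduce to param[0] + param[1]*x_1, so using `> 0` in np.where and `>= 0` for x_3 gives the same surface.

```python
import numpy as np

def model(param, x_1, x_2, x_3, x_4):
    # Original formulation with 0/1 indicators x_3 and x_4
    return param[0] + param[1] * (x_1 + x_2 * x_3 + x_2**2 * x_4)

def function(param, x_1, x_2):
    # np.where formulation: the condition is evaluated per element
    return np.where(x_2 > 0,
                    param[0] + param[1] * (x_1 + x_2),
                    param[0] + param[1] * (x_1 + x_2**2))

param = [10, 8]  # hypothetical values
X1, X2 = np.meshgrid(np.linspace(-2, 2, 9), np.linspace(-3, 3, 13))
x_3 = (X2 >= 0).astype(float)
x_4 = (X2 < 0).astype(float)

same = np.allclose(model(param, X1, X2, x_3, x_4), function(param, X1, X2))
print(same)  # True
```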

Answer:

I think you need numpy.where:

def function(param, x_1, x_2):
    return np.where(x_2 > 0,
                    param[0] + param[1]*(x_1 + x_2),
                    param[0] + param[1]*(x_1 + x_2**2))

How it works:

param = [10,8]
x_1 = np.array([1,2,3])
x_2 = np.array([0,4,10])

Where the mask is True, values are taken from param[0] + param[1]*(x_1 + x_2); where it is False, from param[0] + param[1]*(x_1 + x_2**2):

print(x_2 > 0)
[False  True  True]

print(param[0] + param[1]*(x_1 + x_2))
[ 18  58 114]

print(param[0] + param[1]*(x_1 + x_2**2))
[ 18 154 834]

print(function(param, x_1, x_2))
[ 18  58 114]
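The selection np.where performs can be spelled out with an explicit loop over the same example arrays; the sketch below shows that both give identical results, with np.where doing it without Python-level iteration.

```python
import numpy as np

param = [10, 8]
x_1 = np.array([1, 2, 3])
x_2 = np.array([0, 4, 10])

mask = x_2 > 0
a = param[0] + param[1] * (x_1 + x_2)       # used where mask is True
b = param[0] + param[1] * (x_1 + x_2**2)    # used where mask is False

# np.where(mask, a, b) picks a[i] if mask[i] else b[i], element by element;
# this loop performs the same selection in pure Python
looped = np.array([ai if m else bi for m, ai, bi in zip(mask, a, b)])
vectorized = np.where(mask, a, b)
print(looped)      # [ 18  58 114]
print(vectorized)  # [ 18  58 114]
```

One design consequence: both branch arrays `a` and `b` are computed in full for every element before np.where selects from them. That is harmless here, but it matters if one branch could overflow or divide by zero for elements it will not be used for.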