Python np.where not functioning consistently

Time:12-09

I am using np.where to test the value of a field ('Security') in a row (a) of a data table: if the value is 'CASH' I want to do one thing, and if it is not I want to do something else. I am getting different behaviour depending on what I pass as the "then" and "else" arguments of the function. My example below is taken from my debugger session (I am using Spyder and Python 3.9.13).

np.where(a['Security']=='CASH', 1, 2)
Out[233]: array(2)
np.where(a['Security']=='CASH', re.findall('|'.join(aSecs), a['Description'].upper()), 2)
Out[234]: array([], dtype=float64)

and the value of the field I am testing

a['Security']
Out[235]: 'NOTCASH'

The return of

re.findall('|'.join(aSecs), a['Description'].upper())
Out[236]: []

What could be causing the np.where to evaluate the same clause differently?

CodePudding user response:

From the first, a['Security']=='CASH' must be evaluating as a scalar False:

In [281]: np.where(False, 1, 2)
Out[281]: array(2)

If one of the choices is a (0,) shape array, the result will also be (0,) shape:

In [282]: np.where(False, np.array([]), 2)
Out[282]: array([], dtype=float64)
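
To connect this to the question: re.findall returns a plain list, and an empty list becomes a (0,) float64 array once np.where converts its arguments. A minimal sketch, with `matches` as a hypothetical stand-in for the empty re.findall result:

```python
import numpy as np

# Hypothetical stand-in for re.findall(...) when nothing matches.
matches = []

# np.where converts its arguments to arrays before broadcasting them.
as_array = np.asarray(matches)
print(as_array.shape, as_array.dtype)

# The scalar condition and the scalar 2 are both broadcast to shape (0,).
result = np.where(False, matches, 2)
print(result.shape, result.dtype)
```

So the condition is evaluated the same way in both calls; only the broadcast shape of the output differs.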

The key is that np.where broadcasts the three arguments against each other. In broadcasting, all size 1 dimensions (and scalars) are stretched to match the others, even when the other size is 0.

Scaling up to (2,) is something we often see, but scaling down to (0,) is also allowed:

In [283]: np.where(False, np.array([1,2]), 2)
Out[283]: array([2, 2])
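
The broadcast shape can also be checked without building any arrays. This is a sketch that assumes NumPy 1.20 or later, where np.broadcast_shapes is available; the variable names are only for illustration:

```python
import numpy as np

# Scalar condition, (0,) choice, scalar choice: everything shrinks to (0,).
empty_case = np.broadcast_shapes((), (0,), ())
print(empty_case)

# Scalar condition, (2,) choice, scalar choice: everything stretches to (2,).
pair_case = np.broadcast_shapes((), (2,), ())
print(pair_case)

# A size 1 dimension is treated like a scalar and also matches size 0.
one_vs_zero = np.broadcast_shapes((1,), (0,))
print(one_vs_zero)
```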

If the condition is (2,), the result will also be (2,):

In [284]: np.where([True,False],2,3)
Out[284]: array([2, 3])

But a (2,) condition cannot be broadcast against a (0,) choice, so this fails:

In [285]: np.where([True,False],[],3)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [285], in <cell line: 1>()
----> 1 np.where([True,False],[],3)

File <__array_function__ internals>:5, in where(*args, **kwargs)

ValueError: operands could not be broadcast together with shapes (2,) (0,) () 

np.where may be more complex than you need for just one row. A plain conditional expression is enough:

In [286]: [] if False else 2
Out[286]: 2

In [287]: [] if True else 2
Out[287]: []
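
Applied to the row in the question, that might look like the sketch below. The values of `aSecs` and `a` are made up for illustration (the question does not show them), and `classify` is a hypothetical helper name:

```python
import re

# Hypothetical stand-ins for the question's variables.
aSecs = ["AAPL", "MSFT"]
a = {"Security": "NOTCASH", "Description": "Bought aapl shares"}

def classify(row, patterns):
    """Return the matched security names for a CASH row, otherwise 2."""
    if row["Security"] == "CASH":
        return re.findall("|".join(patterns), row["Description"].upper())
    return 2

print(classify(a, aSecs))

# A made-up CASH row, to exercise the other branch.
cash_row = {"Security": "CASH", "Description": "Transfer for aapl purchase"}
print(classify(cash_row, aSecs))
```

Unlike np.where, which receives all three arguments already evaluated, the if statement only runs re.findall when the row is actually 'CASH', and the result keeps its ordinary Python type.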