NumPy: How to elegantly index an array that may be 0-dimensional?

Time:06-22

Minimal example:

import numpy as np

def foo(arr):
    negative = arr < 0
    arr2 = arr + 1
    arr2[negative] *= -1
    return arr2

a = np.array([1])
b = np.array(1)
print(foo(a)) # works, also works for any other nonzero dimensional array
print(foo(b)) # TypeError: 'numpy.int64' object does not support item assignment

Is there an elegant way to make foo support both zero and nonzero dimensional arrays? I would prefer to use something 'pythonic' rather than branching based on arr2.ndim.

EDIT: this version of foo seems to work without the above issue, though it modifies the input. Why does this work but not the above?

def foo(arr):
    negative = arr < 0
    arr += 1
    arr[negative] *= -1
    return arr

Answer:

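What about using numpy.atleast_1d? (Note that the result is then always at least 1-dimensional, so a 0-d input comes back as a 1-element array.)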

def foo(arr):
    negative = arr < 0
    arr2 = np.atleast_1d(arr + 1)
    arr2[negative] *= -1
    return arr2

foo(np.array(-5))

Output:

array([4])
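If the input's shape should be preserved (a 0-d array in, a 0-d array out), one option is to reshape the result back at the end. This is a minimal sketch building on the answer above; the name foo_atl1d_keepshape is hypothetical, and the extra np.asarray call is an assumption so that plain Python numbers are also accepted.

```python
import numpy as np

def foo_atl1d_keepshape(arr):
    arr = np.asarray(arr)
    negative = np.atleast_1d(arr < 0)
    arr2 = np.atleast_1d(arr + 1)     # at least 1-d, so item assignment works
    arr2[negative] *= -1
    return arr2.reshape(arr.shape)    # back to the input shape, () for a 0-d input

print(repr(foo_atl1d_keepshape(np.array(-5))))  # array(4)
```

For inputs that are already at least 1-d, np.atleast_1d returns its argument unchanged, so the reshape is a no-op there.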

Answer:

You could add a leading axis of length 1 (increasing the number of dimensions by one), perform the operations, and then remove that axis again:

def foo(arr):
    arr = np.array([arr])       # add a leading axis of length 1
    negative = arr < 0
    arr2 = arr + 1
    arr2[negative] *= -1
    return np.array(*arr2)      # unpack the length-1 axis, restoring the original shape

foo(np.array(-5))

Output:

array(4)
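The same add-an-axis-then-remove-it idea can also be written with np.newaxis and an Ellipsis index, which avoids building a new array from a list. Indexing with [0, ...] returns a 0-d array, whereas [0] alone would return a NumPy scalar. A sketch, with foo_newaxis as a hypothetical name:

```python
import numpy as np

def foo_newaxis(arr):
    arr = np.asarray(arr)
    up = arr[np.newaxis]      # view with an extra leading axis of length 1
    negative = up < 0
    arr2 = up + 1             # new array with ndim >= 1, so it is writeable
    arr2[negative] *= -1
    return arr2[0, ...]       # remove the leading axis; `...` keeps a 0-d array instead of a scalar

print(repr(foo_newaxis(np.array(-5))))  # array(4)
```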

Answer:

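The problem is that arr2 is not WRITEABLE, while arr is. For a 0-d input, arr + 1 does not return a 0-d array but a NumPy scalar (numpy.int64 here), and NumPy scalars are read-only: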

for x in (np.array(1), np.array(1) + 1):
    print(x.dtype, x.ndim)
    print(x.flags)
# int64 0
#   C_CONTIGUOUS : True
#   F_CONTIGUOUS : True
#   OWNDATA : True
#   WRITEABLE : True
#   ALIGNED : True
#   WRITEBACKIFCOPY : False
#   UPDATEIFCOPY : False

# int64 0
#   C_CONTIGUOUS : True
#   F_CONTIGUOUS : True
#   OWNDATA : True
#   WRITEABLE : False
#   ALIGNED : True
#   WRITEBACKIFCOPY : False
#   UPDATEIFCOPY : False

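This is what makes the item assignment fail. Your edited version works because arr += 1 modifies the 0-d input array in place, so arr remains a writeable ndarray (which is also why it modifies the input).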
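The difference between the out-of-place and in-place versions can be checked directly. This small sketch only uses the behaviour described above: a + 1 on a 0-d array yields a NumPy scalar, while a += 1 keeps the original 0-d ndarray.

```python
import numpy as np

a = np.array(1)      # 0-d ndarray
b = a + 1            # out-of-place: for a 0-d input NumPy returns a scalar, not an ndarray
print(type(a).__name__, type(b).__name__)

try:
    b[b > 0] = 0     # NumPy scalars are immutable, so this raises TypeError
except TypeError as exc:
    print("TypeError:", exc)

a += 1               # in-place: `a` is still the same writeable 0-d ndarray
a[a > 0] *= -1       # boolean item assignment works on a 0-d array
print(a)
```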


There are a number of workarounds for this:

  • copy the input beforehand, so that you can use the += operator to modify the copy in place (the copy is still an ndarray, so it stays writeable):
def foo_cp(arr):
    arr = arr.copy()
    negative = arr < 0
    arr += 1
    arr[negative] *= -1
    return arr
  • use branching as you originally indicated:
def foo_if(arr):
    negative = arr < 0
    arr2 = arr + 1
    if negative.ndim > 0:
        arr2[negative] *= -1
    elif arr < 0:
        arr2 = -arr2
    return arr2
  • use an equivalent algebraic expression:
def foo_float(arr):
    return np.sign(np.sign(arr) + 0.5).astype(np.int8) * (arr + 1)
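Another branch-free option, not taken from the answers here, is np.where: it builds a new array from the two candidate results instead of assigning into one, so it never needs a writeable array. A sketch (foo_where is a hypothetical name); note that np.where evaluates both candidates for every element.

```python
import numpy as np

def foo_where(arr):
    arr2 = arr + 1                              # may be a NumPy scalar for 0-d input; fine here
    return np.where(arr < 0, -arr2, arr2)       # take -(arr + 1) where arr is negative

print(repr(foo_where(np.array(-5))))
print(foo_where(np.arange(-3, 3)))
```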

For integers, the above can be simplified to:

def foo_int(arr):
    return np.abs(arr   1)
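The simplification relies on the input being integral: for an integer x < 0, x + 1 <= 0, so -(x + 1) equals |x + 1|, and for x >= 0, x + 1 is already positive. For a float in (-1, 0), x + 1 is positive, so abs() does not flip the sign the way the original foo() does. A small check, using a hypothetical reference function foo_ref that applies the original algorithm to a copy:

```python
import numpy as np

def foo_int(arr):
    return np.abs(arr + 1)

def foo_ref(arr):
    # the original algorithm, applied to a writeable copy of the input
    arr = np.array(arr)               # np.array copies by default
    negative = arr < 0
    arr += 1
    arr[negative] *= -1
    return arr

ints = np.arange(-20, 20)
print(np.array_equal(foo_int(ints), foo_ref(ints)))   # True: identical on integers

x = np.array(-0.5)
print(foo_int(x), foo_ref(x))   # 0.5 -0.5: abs() loses the sign for floats in (-1, 0)
```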

Below are some tests showing that all of these produce the correct output for the indicated use cases (foo_int() only for integer inputs):

funcs = foo_cp, foo_if, foo_float, foo_int, foo_atl1d, foo_updown

int_arrs = [np.arange(-20, 20), np.array([2]), np.array(2), np.array([-1]), np.array(-1), np.array(-2)]
float_arrs = [np.arange(-20, 20) / np.pi, np.array([1.2]), np.array(1.2), np.array([-0.5]), np.array(-0.5), np.array(1.2)]

int_base = [funcs[0](arr) for arr in int_arrs]
float_base = [funcs[0](arr) for arr in float_arrs]
for func in funcs:
    int_res = [func(arr) for arr in int_arrs]
    float_res = [func(arr) for arr in float_arrs]
    is_good_int = all(np.allclose(x, y) for x, y in zip(int_base, int_res))
    is_good_float = all(np.allclose(x, y) for x, y in zip(float_base, float_res))
    print(f"{func.__name__:>12} {is_good_int!s:>5} {is_good_float!s:>5} ", end="")
    # %timeit is an IPython magic: run this in IPython or Jupyter
    %timeit -n 64 -r 4 [func(arr) for arr in int_arrs]; [func(arr) for arr in float_arrs]
#       foo_cp  True  True 64 loops, best of 4: 74.6 µs per loop
#       foo_if  True  True 64 loops, best of 4: 53.7 µs per loop
#    foo_float  True  True 64 loops, best of 4: 85.8 µs per loop
#      foo_int  True False 64 loops, best of 4: 28.4 µs per loop
#    foo_atl1d  True  True 64 loops, best of 4: 177 µs per loop
#   foo_updown  True  True 64 loops, best of 4: 115 µs per loop
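Because %timeit only works in IPython, here is a rough plain-Python equivalent using the standard-library timeit module. It is a sketch that re-defines only foo_cp, foo_if, foo_float and foo_int from this answer (the two functions from the other answers are left out), and the timings it prints will depend on the machine and NumPy version.

```python
import timeit

import numpy as np


def foo_cp(arr):
    arr = arr.copy()              # writeable copy, also for 0-d input
    negative = arr < 0
    arr += 1                      # in place, so arr stays an ndarray
    arr[negative] *= -1
    return arr


def foo_if(arr):
    negative = arr < 0
    arr2 = arr + 1                # NumPy scalar for 0-d input
    if negative.ndim > 0:
        arr2[negative] *= -1
    elif arr < 0:
        arr2 = -arr2
    return arr2


def foo_float(arr):
    return np.sign(np.sign(arr) + 0.5).astype(np.int8) * (arr + 1)


def foo_int(arr):
    return np.abs(arr + 1)        # only valid for integer inputs


int_arrs = [np.arange(-20, 20), np.array([2]), np.array(2), np.array([-1]), np.array(-1), np.array(-2)]
float_arrs = [np.arange(-20, 20) / np.pi, np.array([1.2]), np.array(1.2), np.array([-0.5]), np.array(-0.5), np.array(1.2)]

funcs = (foo_cp, foo_if, foo_float, foo_int)
int_base = [foo_cp(arr) for arr in int_arrs]
float_base = [foo_cp(arr) for arr in float_arrs]

results = {}
for func in funcs:
    is_good_int = all(np.allclose(x, func(arr)) for x, arr in zip(int_base, int_arrs))
    is_good_float = all(np.allclose(x, func(arr)) for x, arr in zip(float_base, float_arrs))
    results[func.__name__] = (is_good_int, is_good_float)
    # best of 4 repeats, 64 loops each, over all test inputs
    best = min(timeit.repeat(lambda: [func(arr) for arr in int_arrs + float_arrs], number=64, repeat=4)) / 64
    print(f"{func.__name__:>12} {is_good_int!s:>5} {is_good_float!s:>5} {best * 1e6:8.1f} us per loop")
```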

(I have also included foo_atl1d() from @mozway's answer and foo_updown() from @Nin17's answer).

Note that while foo_cp() and foo_float() may be simpler to write, they were slower than foo_if() in these timings. foo_int() was the fastest, but it is only correct for integer inputs. foo_atl1d() and foo_updown() turned out to be the slowest (for the tested inputs).
