How to make an np.array in numba with input-dependent rank?

Time:03-23

I would like to apply @numba.njit to this simple function, which returns an array whose shape, and in particular whose rank, depends on the input i. For example, for i = 4 the shape should be (2, 2, 2, 2, 4).

import numpy as np
from numba import njit

@njit
def make_array_numba(i):
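    # Build the shape as a 1-D array: [2, 2, ..., 2, i]
    shape = np.array([2] * i + [i], dtype=np.int64)
    # Numba cannot type this call: a 1-D array is not accepted as a shape
    return np.empty(shape, dtype=np.int64)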

make_array_numba(4).shape

I have tried many different ways, but I always fail because I cannot generate the shape tuple that Numba expects in np.empty / np.reshape / np.zeros / ... In plain NumPy one can pass lists or np.arrays as the shape, or build a tuple on the fly such as (2,) * i + (i,).
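For comparison, here is a minimal plain-NumPy sketch (no Numba) of the behaviour described above; make_array is a hypothetical name. NumPy accepts a tuple built at runtime, or an integer array, as the shape.

```python
import numpy as np


def make_array(i):
    # Plain Python/NumPy: the tuple length can depend on i at runtime
    shape = (2,) * i + (i,)
    return np.empty(shape, dtype=np.int64)


print(make_array(4).shape)                 # (2, 2, 2, 2, 4)
# NumPy also accepts a 1-D integer array as the shape
print(np.empty(np.array([2, 2, 4])).shape) # (2, 2, 4)
```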

Output:

>>> empty(array(int64, 1d, C), dtype=class(int64))
 
There are 4 candidate implementations:
      - Of which 4 did not match due to:
      Overload in function '_OverloadWrapper._build.<locals>.ol_generated': File: numba/core/overload_glue.py: Line 131.
        With argument(s): '(array(int64, 1d, C), dtype=class(int64))':
       Rejected as the implementation raised a specific error:
         TypingError: Failed in nopython mode pipeline (step: nopython frontend)
       No implementation of function Function(<intrinsic stub>) found for signature:
        
        >>> stub(array(int64, 1d, C), class(int64))
        
       There are 2 candidate implementations:
         - Of which 2 did not match due to:
         Intrinsic of function 'stub': File: numba/core/overload_glue.py: Line 35.
           With argument(s): '(array(int64, 1d, C), class(int64))':
          No match.

Answer:

This is not possible with @njit alone. The reason is that Numba has to determine the type of the array independently of variable values in order to compile the function, and only then execute it. The catch is that the number of dimensions of an array is part of its type. Here, Numba cannot infer the type of the array because it depends on a value that is not a compile-time constant.

The only way to solve this problem (assuming you do not want to linearize your array) is to recompile the function for each possible i, which is certainly overkill and completely defeats the benefit of using Numba (at least in your example). Note that @generated_jit can be used in such a case when you really want to recompile the function for different values or input types. I strongly advise you not to use it for your current use case. If you try, you will run into other, similar issues because the array cannot be indexed with a runtime-defined number of indices, and the resulting code will quickly become unmanageable.

A more general and cleaner solution is simply to linearize the array. This means flattening it and computing the index yourself, for example as (((... + z) * stride_z) + y) * stride_y + x. Both the size and the index can then be computed at runtime, independently of the typing system. Note that this index computation can be somewhat slow, but NumPy would not use faster code in this case either.
