How to return a 1D structured array (mixed types) from a numba-JIT-compiled function?

Time:06-25

I am trying to use numba to JIT-compile a function in nopython mode that has multiple return values, e.g.

def foo(ii: int) -> tuple[int, int, float, float, int]:
    return 1 + ii, 2 + ii, 3.0 + ii, 4.0 + ii, 5 + ii

The closest I can think of is the following:

import numpy as np
import numba as nb

@nb.njit( 
    nb.from_dtype(np.dtype('i8,i8,f8,f8,i8'))[:](nb.int64)  # signature
)
def foo(ii):
    return np.array(
        [(
            1 + ii,  # int64
            2 + ii,  # int64
            3.0 + np.float64(ii),  # float64
            4.0 + np.float64(ii),  # float64
            5 + ii,  # int64
        )],
        dtype = 'i8,i8,f8,f8,i8',  # structured array
    )

a, b, c, d, e = foo(0)[0]  # for testing: unpack the single record
print(a, b, c, d, e)

This causes TypingError: Failed in nopython mode pipeline (step: nopython frontend).

The critical part of the error message reads as follows (line breaks inserted for readability):

No conversion from

unaligned array(
    Record(
        f0[type=int64;offset=0],
        f1[type=int64;offset=8],
        f2[type=float64;offset=16],
        f3[type=float64;offset=24],
        f4[type=int64;offset=32]
        ;40;False
    ), 2d, C
)

to

unaligned array(
    Record(
        f0[type=int64;offset=0],
        f1[type=int64;offset=8],
        f2[type=float64;offset=16],
        f3[type=float64;offset=24],
        f4[type=int64;offset=32]
        ;40;False
    ), 1d, A
)

for '$58return_value.28', defined at None

Digging through the numba manual, I cannot make complete sense of this. One of the key differences is "1d" vs "2d" (the layouts also differ: C vs A). So I have a dimension mismatch, but despite trying many variations on the notation, I am unable to fix it.

Answer:

Keep it simple. The first function, which returns a plain tuple, can be JIT-compiled as is, because numba supports heterogeneous tuples as return values. If you want an explicit signature, you can specify it like this:

import numba as nb

@nb.njit('Tuple([i8,i8,f8,f8,i8])(i8)')
def foo(ii: int) -> tuple[int, int, float, float, int]:
    return 1 + ii, 2 + ii, 3.0 + ii, 4.0 + ii, 5 + ii