I am trying to use numba to JIT-compile a function in nopython mode that has multiple return values, e.g.
def foo(ii: int) -> tuple[int, int, float, float, int]:
    return 1 * ii, 2 * ii, 3.0 * ii, 4.0 * ii, 5 * ii
The closest I can think of is the following:
import numpy as np
import numba as nb
@nb.njit(
    nb.from_dtype(np.dtype('i8,i8,f8,f8,i8'))[:](nb.int64)  # signature
)
def foo(ii):
    return np.array(
        [(
            1 * ii,  # int64
            2 * ii,  # int64
            3.0 * np.float64(ii),  # float64
            4.0 * np.float64(ii),  # float64
            5 * ii,  # int64
        )],
        dtype='i8,i8,f8,f8,i8',  # structured array
    )
a, b, c, d, e = foo(0)  # for testing
print(a, b, c, d, e)
This raises TypingError: Failed in nopython mode pipeline (step: nopython frontend).
The critical part of the error message reads as follows (line breaks inserted for readability):
No conversion from
unaligned array(
Record(
f0[type=int64;offset=0],
f1[type=int64;offset=8],
f2[type=float64;offset=16],
f3[type=float64;offset=24],
f4[type=int64;offset=32]
;40;False
), 2d, C
)
to
unaligned array(
Record(
f0[type=int64;offset=0],
f1[type=int64;offset=8],
f2[type=float64;offset=16],
f3[type=float64;offset=24],
f4[type=int64;offset=32]
;40;False
), 1d, A
)
for '$58return_value.28', defined at None
Digging through the numba manual, I cannot make complete sense of this. One of the key differences is "1d" vs "2d", so I clearly have a dimension mismatch; however, despite trying many variations on the notation, I have not been able to fix it.
Answer:
Keep it simple. The first function can be successfully jitted and returns a tuple. You can specify the signature using the following code:
@nb.njit('Tuple([i8,i8,f8,f8,i8])(i8)')
def foo(ii: int) -> tuple[int, int, float, float, int]:
    return 1 * ii, 2 * ii, 3.0 * ii, 4.0 * ii, 5 * ii