I have several numba-jitted event functions that share the same signature, for example:
from numba import jit

@jit("Tuple((float64,float64))(float64[::1])", nopython=True)
def event_1(y):
    return 1.1, 1.2  # (random values for this example)

@jit("Tuple((float64,float64))(float64[::1])", nopython=True)
def event_2(y):
    return 2.1, 2.2  # (random values for this example)
My goal is to create a jitted function that returns a List of event functions. The output of event_handler is meant to be the input of another cache-compiled jitted function, so its signature must be fixed:
from numba.typed import List

@jit("ListType(FunctionType(Tuple((float64, float64))(float64[::1])))()", nopython=True)
def event_handler():
    return List([event_1, event_2])
However, the code above only works if the List returned by event_handler contains at least two different event functions. If the List has only one event-function item, or multiple items of the same function (e.g., either List([event_1]) or List([event_1, event_1])), the code fails to compile and produces the following error:
No conversion from ListType[type(CPUDispatcher(<function event_1 at 0x7f83c2a22430>))] to ListType[FunctionType[UniTuple(float64 x 2)(array(float64, 1d, C))]]
I believe this happens because, in the latter case, the List item type is inferred as the function's CPUDispatcher type rather than as the typed function signature (FunctionType). I've already tried this solution to initialise the list, but unfortunately it does not work.
How can I solve this issue? It seems strange to me that the list type is inferred correctly when two different functions are provided, but suddenly becomes a CPUDispatcher type whenever a single distinct function is given.
CodePudding user response:
I found a solution that involves explicitly typing the list items and using the empty_list method of the typed List. With reference to my original post, I can initialize the List with the empty_list method, passing the item type as its argument (defined outside the scope of any jitted function):
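from numba import jit, types
from numba.typed import List

# Item type of the list: a first-class function taking a C-contiguous
# 1-D float64 array and returning a pair of float64 values.
vector_sig = types.Array(dtype=types.float64, ndim=1, layout="C")
function_sig = types.FunctionType(types.Tuple((types.float64, types.float64))(vector_sig))

@jit("ListType(FunctionType(Tuple((float64, float64))(float64[::1])))()", nopython=True)
def event_handler():
    # Create an empty, explicitly typed list, then append the dispatcher(s).
    event_list = List.empty_list(function_sig)
    event_list.append(event_1)
    return event_list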