I'm pretty inexperienced with Numba (and with posting questions), so hopefully this isn't a mis-specified question.
I am trying to create a jitted function that involves a dictionary. I want the dictionary to have tuples as keys and arrays of floats as values. Below is some code adapted from an example in the Numba documentation, which I've used to demonstrate my question.
I understand Numba wants variable types to be specified. I think the problem is that I am not specifying the right Numba type for the dictionary key inside the function. I've looked at this question but still can't figure out what to do.
```python
import numpy as np
from numba import njit
from numba import types
from numba.typed import Dict

# Make array type. Type-expression is not supported in jit functions.
float_array = types.float64[:]

@njit
def foo():
    list_out = []
    # Make dictionary
    d = Dict.empty(
        key_type=types.Tuple,  # <= I suppose I'm not putting the right 'type' here
        value_type=float_array,
    )
    # an example of how I would like to fill the dictionary
    d[(1, 1)] = np.arange(3).astype(np.float64)
    d[(2, 2)] = np.arange(3, 6).astype(np.float64)
    list_out.append(d[(2, 2)])
    return list_out

list_out = foo()
```
Any help or guidance is appreciated. Thanks for your time!
Answer:
`types.Tuple` on its own is an incomplete type, so it is not a valid key type: you need to specify the types of the items in the tuple. In this case, you can use `types.UniTuple(types.int32, 2)` as a complete key type (a tuple containing two 32-bit integers). Here is the resulting code:
```python
import numpy as np
from numba import njit
from numba import types
from numba.typed import Dict

# Make key type with two 32-bit integer items.
key_type = types.UniTuple(types.int32, 2)

# Make array type. Type-expression is not supported in jit functions.
float_array = types.float64[:]

@njit
def foo():
    list_out = []
    # Make dictionary
    d = Dict.empty(
        key_type=key_type,
        value_type=float_array,
    )
    # an example of how I would like to fill the dictionary
    d[(1, 1)] = np.arange(3).astype(np.float64)
    d[(2, 2)] = np.arange(3, 6).astype(np.float64)
    list_out.append(d[(2, 2)])
    return list_out

list_out = foo()
```
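By the way, be aware that `np.arange` accepts a `dtype` parameter, so you can write `np.arange(3, dtype=np.float64)` directly. This is more efficient than calling `astype` afterwards, since it avoids creating an intermediate integer array and then copying it.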