How can a numba jitclass have an argument that can be either a float64 or a float32? With functions, the following code works:
import numba
import numpy as np
from numba import njit
from numba.experimental import jitclass

@njit()
def f(a):
    print(a.dtype)
    return a[0]

a = np.zeros(3)
f(a)
f(a.astype(np.float32))
whereas trying to use both float32 and float64 with class attributes fails:
@jitclass([('arr', numba.types.float64[:])])
class MyClass():
    def __init__(self):
        pass

    def f(self, a):
        self.arr = a

myclass = MyClass()
myclass.f(np.zeros(3))
# following line fails:
myclass.f(np.zeros(3, dtype=np.float32))
Is there a workaround?
Answer:
When you call MyClass(), Numba needs to instantiate the class, and because Numba only works with well-defined, strongly typed values (this is what makes it fast and so useful), the fields of the class need to be typed before an object is instantiated. Thus, you cannot define the types of the MyClass fields when the method f is called, because that call is made by the CPython interpreter, which is dynamic. Note that a class usually has more than one method (otherwise such a class would not be very useful), which is why partial compilation is not really possible either.
One simple solution is to create two jitclass types from the same Python class:
class MyClass():
    def __init__(self):
        pass

    def f(self, a):
        self.arr = a

MyClass_float32 = jitclass([('arr', numba.types.float32[:])])(MyClass)
MyClass_float64 = jitclass([('arr', numba.types.float64[:])])(MyClass)

myclass = MyClass_float32()  # Compiled lazily on first instantiation, then an object is created
# The type of `self.arr` is already fixed at this point: it is `float32[:]`.
myclass.f(np.zeros(3, dtype=np.float32))

myclass = MyClass_float64()
myclass.f(np.zeros(3, dtype=np.float64))
Another answer:
Numba supports templated functions, where the implementation is selected according to the argument types, like this:
import numpy as np
from numba import generated_jit, types

@generated_jit(nopython=True)
def is_missing(x):
    """
    Return True if the value is missing, False otherwise.
    """
    if isinstance(x, types.Float):
        return lambda x: np.isnan(x)
    elif isinstance(x, (types.NPDatetime, types.NPTimedelta)):
        # The corresponding Not-a-Time value
        missing = x('NaT')
        return lambda x: x == missing
    else:
        return lambda x: False
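The trick is to use the `generated_jit` decorator: the decorated Python function is run at compile time with the Numba types of the arguments and returns the implementation to compile for those types.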