numba jit: failed type inference due to: non-precise type pyobject

Time:12-18

I am trying to use jit to speed up the following function, which saves the index positions of the different labels present in a (large) 3-dimensional np.ndarray.

import numpy as np
from numba import jit

@jit
def find_position(x, values):
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            for k in range(x.shape[2]):
                values[x[i,j,k]].append((i,j))
    return values

labels = [1,2,3]
values = {l: [] for l in labels}
x = np.random.choice(labels,1000).reshape((10,10,10))

v = find_position(x, values)

However, I end up with the following error message:

Compilation is falling back to object mode WITH looplifting enabled because Function "find_position" failed type inference due to: non-precise type pyobject
During: typing of argument at <stdin> (3)

Does anyone have any tips on how to get around that?

CodePudding user response:

Numba cannot operate directly on plain CPython objects such as the dict of lists used here (Numba calls such untyped containers "reflected"). To generate fast native code, Numba needs to know the types of the objects statically, and containers must be homogeneous, whereas CPython objects are dynamically typed and lists/dicts can contain heterogeneous objects. This is also what makes CPython data structures slow compared to Numba's. The two approaches are incompatible. CPython can, however, work with typed data structures, at the expense of additional overhead. Numba can only operate on typed data structures, which can be created from non-typed (i.e. reflected) ones, but creating them is pretty slow. In the end, it may not be worth it.

You can create and fill a typed dictionary like this:

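import numba as nb

# Element type: a pair of 64-bit integers, i.e. an (i, j) index
TupleType = nb.types.UniTuple(nb.int64, 2)
# Value type: a typed list of such pairs
ValueType = nb.types.ListType(TupleType)
# Typed dict mapping an int64 label to a typed list of pairs
values = nb.typed.Dict.empty(nb.int64, ValueType)
for l in labels:  # labels as defined in the question
    values[l] = nb.typed.List.empty_list(TupleType)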

Since values is typed, you can then pass it to Numba.

Note that dictionaries are generally slow (whatever the language), as pointed out by @Rafnus. It is therefore better to use arrays if you can, especially if the labels are small integers: you can build an array of labelMax + 1 items, where labelMax is the largest label value. This assumes that the label IDs are non-negative. If the number of labels is very small (e.g. < 4) and known at compile time, then the method of @Rafnus may be faster.
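One way to read the array suggestion is a two-pass "count, then fill" scheme, sketched below. This is only an illustration under stated assumptions: the function name find_position_arrays and the positions/offsets layout are hypothetical, labels are assumed to be non-negative integers, and the try/except shim is there only so the sketch still runs (slowly) where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: fall back to plain Python without Numba
    def njit(func):
        return func

@njit
def find_position_arrays(x, label_max):
    # Pass 1: count how often each label occurs (array index = label value)
    counts = np.zeros(label_max + 1, dtype=np.int64)
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            for k in range(x.shape[2]):
                counts[x[i, j, k]] += 1
    # offsets[l] is the first row of label l's block in the output array
    offsets = np.zeros(label_max + 2, dtype=np.int64)
    for l in range(label_max + 1):
        offsets[l + 1] = offsets[l] + counts[l]
    # Pass 2: write each (i, j) pair into the block of its label
    positions = np.empty((x.size, 2), dtype=np.int64)
    cursor = offsets[:-1].copy()
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            for k in range(x.shape[2]):
                l = x[i, j, k]
                positions[cursor[l], 0] = i
                positions[cursor[l], 1] = j
                cursor[l] += 1
    return positions, offsets

labels = [1, 2, 3]
x = np.random.choice(labels, 1000).reshape((10, 10, 10))
positions, offsets = find_position_arrays(x, max(labels))
# The (i, j) pairs of label 2 are one contiguous slice
pos_2 = positions[offsets[2]:offsets[3]]
```

The result is a single preallocated array with no dictionary lookup or append in the hot loop, at the cost of traversing the input twice.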

CodePudding user response:

When using numba, there are a couple of things that are important to know:

  • append() is very slow and uses a lot of memory. It is much better to preallocate a NumPy array and then fill it with values as you go.
  • Try to use NumPy functions when working with Numba, as Numba generally has very good support for them.
  • Avoid dictionaries! Numba technically has limited support for them since version 0.52 [see the documentation here], but they are not required to solve your problem, and for most problems there are better (more performant) options available.
  • Do not use @jit, use @njit instead! When type inference fails, as it does in your case, @jit falls back to object mode (this is what the warning in your question says), which means it handles all values as Python objects and uses the Python C API to perform all operations on those objects. A function compiled in object mode is basically no faster than normal Python code. With @njit (short for nopython mode), Numba raises an error instead of silently falling back, and once the code compiles you will see a performance increase.

With all this in mind, I rewrote your code as follows:

import numpy as np
from numba import njit

@njit
def find_position(inpt):
    # the 400 here is chosen arbitrarily. In this example, the final arrays will have a 
    # length of somewhere around 310-350, so with 400 we can be certain that
    # the arrays will be large enough to hold all the values
    arr1 = np.zeros((400, 2))
    arr2 = np.zeros((400, 2))
    arr3 = np.zeros((400, 2))
    x, y, z = 0, 0, 0
    for i in range(inpt.shape[0]):
        for j in range(inpt.shape[1]):
            for k in range(inpt.shape[2]):
                val = inpt[i,j,k]
                if val == 1:
                    arr1[x] = [i, j]
                    x += 1
                elif val == 2:
                    arr2[y] = [i, j]
                    y += 1
                else:
                    arr3[z] = [i, j]
                    z += 1
    return arr1[:x], arr2[:y], arr3[:z]

labels = [1,2,3]
np.random.seed(4)
inpt = np.random.choice(labels,size = (10, 10, 10))
v = find_position(inpt)
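For comparison, the advice to lean on NumPy functions can be taken one step further: for an input of this size the same (i, j) pairs can be collected without Numba at all. The following is a sketch of a different technique (np.argwhere on a boolean mask), not the answer's method; the positions dictionary is an illustrative name.

```python
import numpy as np

labels = [1, 2, 3]
np.random.seed(4)
inpt = np.random.choice(labels, size=(10, 10, 10))

# np.argwhere returns one (i, j, k) row per matching element, in row-major
# order; keeping the first two columns gives the (i, j) pairs.
positions = {l: np.argwhere(inpt == l)[:, :2] for l in labels}
```

This scans the array once per label, so it suits a small number of labels; no arbitrary buffer size such as 400 is needed, because each result array is sized by NumPy.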