I have two Python functions that I am trying to speed up with njit, as they are impacting the performance of my program. Below is an MWE that reproduces the following error when we add the @njit(fastmath=True) decorator to f; otherwise it works. I believe the error is because the array A has dtype object. Can I use Numba to decorate f in addition to g? If not, what is the fastest way to map g to the elements of A? Roughly, len(A) == len(B) ~ 5000, but these functions are called around 500 million times as part of an HPC workflow.
import numpy as np
from numba import njit

@njit(fastmath=True)
def g(a, B):
    # some function of a and B ('*' assumed; the operator was lost in formatting)
    return 19.12 / (len(a) * len(B))

def f(A, B):
    total = 0.0
    for i in range(len(B)):
        total += g(A[i], B)
    return total
A = [[2, 5], [4, 5, 6, 7], [0, 8], [6, 7], [1, 8], [0, 1], [1, 3], [1, 3], [2, 4]]
B = [1, 1, 1, 1, 1, 1, 1, 1, 1]
A = np.array([np.array(a, dtype=int) for a in A], dtype=object)
B = np.array(B, dtype=int)
f(A, B)
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
non-precise type array(pyobject, 1d, C)
During: typing of argument at /var/folders/9x/hnb8fg0x2p1c9p69p_70jnn40000gq/T/ipykernel_59724/1681580915.py (8)
File "../../../../var/folders/9x/hnb8fg0x2p1c9p69p_70jnn40000gq/T/ipykernel_59724/1681580915.py", line 8:
<source missing, REPL/exec in use?>
CodePudding user response:
Can I use Numba to decorate f in addition to g?
No. You cannot use CPython objects in an @njit-decorated Numba function. Numba is fast mainly because it operates on native types, which enables the generation of fast compiled code as opposed to interpreted dynamic code.
If not, what is the fastest way to map g to the elements of A?
Jagged arrays are inefficient. In general, a fast solution to this problem is to use two arrays: one containing all the values, and one containing the start/end range of values for each row (a bit like sparse matrices, but using ranges). Storing the length of each segment also works (and is more compact), though recovering the start/end ranges then requires a cumulative sum, which sometimes makes things more complex (e.g. the dependencies prevent a straightforward parallelization).
CodePudding user response:
To create the non-jagged array @Jérôme Richard mentions, we can do this:
# Imports.
import numpy as np
from numba import njit, prange
# Data.
A_list = [[2, 5], [4, 5, 6, 7], [0, 8], [6, 7], [1, 8], [0, 1], [1, 3], [1, 3], [2, 4]]
B_list = [1, 1, 1, 1, 1, 1, 1, 1, 1]
A_lengths = np.array([len(sublist) for sublist in A_list])
dim1 = len(A_list)
dim2 = A_lengths.max()
A = np.empty(shape=(dim1, dim2), dtype=int)  # 9x4.
for i, (sublist, length) in enumerate(zip(A_list, A_lengths)):
    A[i, :length] = sublist
B = np.array(B_list, dtype=int)
assert A.shape[0] == B.size
The array A will look something like this:
array([[ 2, 5, xxxxxx, xxxxxx],
[ 4, 5, 6, 7],
[ 0, 8, xxxxxx, xxxxxx],
[ 6, 7, xxxxxx, xxxxxx],
[ 1, 8, xxxxxx, xxxxxx],
[ 0, 1, xxxxxx, xxxxxx],
[ 1, 3, xxxxxx, xxxxxx],
[ 1, 3, xxxxxx, xxxxxx],
[ 2, 4, xxxxxx, xxxxxx]])
The xxxxxx entries are arbitrary values that we get because we created the array with np.empty. This is why you keep A_lengths around: it tells you, for each row, where the data becomes irrelevant.
Back to the calculations: I just added the optimizations to f, namely the @njit(parallel=True) decorator and numba.prange instead of Python's range.
# Calculations.
@njit(fastmath=True)
def g(a, b):
    return 19.12 / (len(a) * len(b))  # '*' assumed, as in the question

@njit(parallel=True)
def f(A, B):
    total = 0.0
    for i in prange(len(B)):
        total += g(A[i], B)  # Numba treats this as a parallel reduction.
    return total
# Test.
print(f(A, B))