I would like to do a simple element-wise division followed by a row-wise average inside a Numba `@jit` function with `nopython=True`.
import numpy as np
from numba import jit,prange,typed
# Small 3x3 float32 operands fed to ``test`` below.
A = np.array(
    [[2, 2, 2],
     [1, 0, 0],
     [1, 2, 1]],
    dtype="float32",
)
B = np.array(
    [[2, 0, 2],
     [0, 1, 0],
     [1, 2, 1]],
    dtype="float32",
)
C = np.array(
    [[2, 0, 1],
     [0, 1, 0],
     [1, 1, 2]],
    dtype="float32",
)
My jit function is:
@jit(nopython=True)
def test(a, b, c):
    """Return the row-wise mean of ``c / (a + b)``, counting cells where
    ``a + b <= 0`` as 0.

    Loop-free and restricted to NumPy features that Numba compiles in
    nopython mode. ``np.divide(..., where=...)`` and ``mean(axis=...)`` are
    not available there, so the mask is applied with ``np.where`` and the
    mean is computed as ``sum(axis=1) / n_cols``.

    Masked-out cells are defined as 0. This is what
    ``np.divide(c, a + b, where=mask, out=np.zeros_like(c))`` would give.
    Without ``out=``, those cells would be left uninitialized.

    Parameters
    ----------
    a, b, c : 2-D arrays of the same shape (the sample data is float32).

    Returns
    -------
    1-D array with one mean per row of ``c``.
    """
    denom = a + b
    mask = denom > 0
    # Swap non-positive denominators for 1 so the division can never divide
    # by zero; those cells are then replaced by 0 in the second np.where.
    safe_denom = np.where(mask, denom, np.ones_like(denom))
    div = np.where(mask, c / safe_denom, np.zeros_like(c))
    # Equivalent to div.mean(axis=1); Numba's mean() does not take `axis`.
    return div.sum(axis=1) / div.shape[1]
# Run the jitted function once on the sample operands.
test_res = test(a=A, b=B, c=C)
However, this throws an error. What would be the workaround? I am trying to do this without a loop; any insight would be appreciated.
Answer:
Numba does not support some arguments of certain NumPy functions in nopython mode. Examples are the `axis` argument of `np.mean()` and the `where` argument of `np.divide`. You can work around this with alternative code, for example explicit loops:
@jit("float64[::1](float32[:, ::1], float32[:, ::1], float32[:, ::1])", nopython=True)
def test(a, b, c):
    """Return the row-wise mean of ``c / (a + b)``, counting cells where
    ``a + b <= 0`` as 0.

    Explicit-loop version that compiles under Numba's nopython mode. The
    signature fixes the inputs to C-contiguous 2-D float32 arrays and the
    output to a C-contiguous 1-D float64 array.

    Unlike an in-place variant, the inputs are only read. In particular,
    ``c`` is not overwritten with the quotients.

    For a parallel version, add ``parallel=True`` to the decorator and
    iterate the outer loop with ``prange`` instead of ``range``.
    """
    n_rows, n_cols = c.shape
    result = np.zeros(n_rows)
    for i in range(n_rows):
        row_sum = 0.0
        for j in range(n_cols):
            denom = a[i, j] + b[i, j]
            # Skip non-positive denominators: they contribute 0 to the sum.
            if denom > 0:
                row_sum += c[i, j] / denom
        result[i] = row_sum
    # Array-by-scalar division keeps NumPy semantics (n_cols == 0 -> nan).
    return result / n_cols