How to use a mask with numba @jit

Time:06-15

I would like to do a simple masked division followed by a row-wise average inside a jit function with nopython=True.

import numpy as np
from numba import jit,prange,typed

A = np.array([[2,2,2],[1,0,0],[1,2,1]], dtype=np.float32)
B = np.array([[2,0,2],[0,1,0],[1,2,1]],dtype=np.float32)
C = np.array([[2,0,1],[0,1,0],[1,1,2]],dtype=np.float32)

My jit function is:

@jit(nopython=True)
def test(a,b,c):
    mask = a + b > 0
    div = np.divide(c, a + b, where=mask)
    result = div.mean(axis=1)

    return result

test_res = test(A,B,C)

However, this throws an error. What would be the workaround? I am trying to do this without a loop; any pointers would be appreciated.
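For reference, here is a sketch of what the function is presumably meant to compute, written in plain NumPy outside of numba. It assumes that entries where a + b <= 0 should simply keep the value from c (the same choice the answer's loop makes). Note that np.divide(..., where=mask) without an out= argument leaves those masked-out entries uninitialized, so out=c.copy() is passed here. The helper name masked_row_mean is hypothetical.

```python
import numpy as np

A = np.array([[2, 2, 2], [1, 0, 0], [1, 2, 1]], dtype=np.float32)
B = np.array([[2, 0, 2], [0, 1, 0], [1, 2, 1]], dtype=np.float32)
C = np.array([[2, 0, 1], [0, 1, 0], [1, 1, 2]], dtype=np.float32)


def masked_row_mean(a, b, c):
    """Hypothetical plain-NumPy reference: divide c by (a + b) where a + b > 0, then average each row."""
    denom = a + b
    mask = denom > 0
    # out=c.copy() makes the masked-out entries well defined (they keep c) and leaves c untouched
    div = np.divide(c, denom, out=c.copy(), where=mask)
    return div.mean(axis=1)


print(masked_row_mean(A, B, C))
```

With the sample arrays A, B and C this gives roughly [0.25, 0.333, 0.583]; it is handy for checking any numba version against.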

Answer:

Numba doesn't support every argument of every NumPy function in nopython mode: np.mean() / ndarray.mean() does not accept the axis argument, and the where argument of np.divide is not supported either. You can work around this with an explicit loop, for example:

import numpy as np
import numba as nb

@nb.njit("float64[::1](float32[:, ::1], float32[:, ::1], float32[:, ::1])")  # for a parallel version, add parallel=True
def test(a, b, c):
    result = np.zeros(c.shape[0])
    for i in range(c.shape[0]):     # for a parallel version: for i in nb.prange(c.shape[0]):
        for j in range(c.shape[1]):
            val = c[i, j]
            denom = a[i, j] + b[i, j]
            if denom > 0:
                # divide only where the mask (a + b > 0) holds
                val = val / denom
            # accumulate the row sum without modifying the input array c
            result[i] += val

    return result / c.shape[1]