I ran into a problem using the numba jit decorator (@nb.jit). Here is the warning from Jupyter Notebook:
NumbaWarning: Compilation is falling back to object mode WITH looplifting enabled because Function "get_nb_freq" failed type inference due to: No implementation of function Function(<function dot
Here is my code:
@numba.jit
def get_nb_freq(nb_count=None, onehot_ct=None):
    # nb_freq = onehot_ct.T @ nb_count
    nb_freq = np.dot(onehot_ct.T, nb_count)
    res = nb_freq / nb_freq.sum(axis=1).reshape(Num_celltype, -1)
    return res

## onehot_ct is an array with shape (921600, 4)
## nb_count is an array with the same shape as onehot_ct
## Num_celltype equals 4
CodePudding user response:
Based on the shapes you mentioned, we can create test arrays like this:
onehot_ct = np.random.rand(921600, 4)
nb_count = np.random.rand(921600, 4)
With these inputs, your code works correctly and returns a result like:
[[0.25013102754197963 0.25021461207825463 0.2496806287276126 0.24997373165215303]
 [0.2501574139037384 0.25018726649940737 0.24975108864220968 0.24990423095464467]
 [0.25020550587624757 0.2501303498983212 0.24978335463279314 0.24988078959263807]
 [0.2501855533482036 0.2500913419625523 0.24979681404573967 0.24992629064350436]]
This shows that the code itself works, so the problem seems to be related to the types of the arrays, which numba cannot infer. An explicit signature may help here, because it lets us declare the argument and return types for the function manually. Based on the error, I think the following signature will get you past the issue:
@nb.jit("float64[:, ::1](float64[:, ::1], float32[:, ::1])")
def get_nb_freq(nb_count=None, onehot_ct=None):
    nb_freq = np.dot(onehot_ct.T, nb_count)
    res = nb_freq / nb_freq.sum(axis=1).reshape(4, -1)
    return res
But it will get stuck again if you test it with:

get_nb_freq(nb_count.astype(np.float64), onehot_ct.astype(np.float32))

So another cause could be the mismatched dtypes passed to np.dot. Converting the onehot_ct array to np.float64 inside the function gets past the issue:

@nb.jit("float64[:, ::1](float64[:, ::1], float32[:, ::1])")
def get_nb_freq(nb_count, onehot_ct):
    nb_freq = np.dot(onehot_ct.astype(np.float64).T, nb_count)
    res = nb_freq / nb_freq.sum(axis=1).reshape(4, -1)
    return res
It ran on my machine with this correction. I recommend writing numba-equivalent code (like this for np.dot) instead of using np.dot or …, which can be much faster.
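As a rough illustration of what "numba-equivalent code" for np.dot could look like here, the following is a minimal sketch that replaces the matrix product and the row normalization with explicit loops. The function name get_nb_freq_loops and the fallback decorator are my own assumptions for illustration, not part of the original code, and I have not benchmarked it against np.dot.

```python
import numpy as np

try:
    import numba as nb
    njit = nb.njit
except ImportError:
    # Assumption: fall back to plain Python so the sketch still runs without numba.
    def njit(func):
        return func


@njit
def get_nb_freq_loops(nb_count, onehot_ct):
    # Hypothetical loop-based stand-in for np.dot(onehot_ct.T, nb_count).
    n_rows, n_types = onehot_ct.shape
    n_cols = nb_count.shape[1]
    nb_freq = np.zeros((n_types, n_cols), dtype=np.float64)
    for i in range(n_rows):
        for a in range(n_types):
            w = onehot_ct[i, a]
            for b in range(n_cols):
                # Mixed float32/float64 inputs are promoted to float64 here.
                nb_freq[a, b] += w * nb_count[i, b]
    # Normalize each row so that it sums to 1, as in the original function.
    for a in range(n_types):
        row_sum = 0.0
        for b in range(n_cols):
            row_sum += nb_freq[a, b]
        for b in range(n_cols):
            nb_freq[a, b] /= row_sum
    return nb_freq
```

Because the accumulation happens element by element into a float64 array, the inputs do not need to share a dtype, so a call such as get_nb_freq_loops(nb_count, onehot_ct.astype(np.float32)) should compile without the explicit cast. Whether it actually beats np.dot depends on the array shapes, so it is worth timing both on your data.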