I would like to use Numba to speed up my code (see the MWE below). However, I get NumbaTypeError: unsupported array index type. What is causing this, and how can I fix it?
import numpy as np
import numba as nb
a = np.array([4, 5, 6, 7, 8, 9], dtype=np.int16)
b = np.array([ 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=np.int16)
c = np.zeros((14, 20, 2), dtype=np.int16)
@nb.njit(fastmath=True)
def printNumbers(a, b, c):
    d = c[a.reshape((a.size, 1)), b, :]
    print(d)
printNumbers(a, b, c)
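For reference, the failing expression is ordinary NumPy advanced indexing. The (6, 1) column built from a and the length-18 vector b broadcast to a (6, 18) grid of (row, column) pairs, and the trailing : keeps the last axis whole. Below is a minimal plain-NumPy sketch (no Numba) of what it computes. Here c is filled with distinct values instead of zeros, an assumption made only so the element check means something:

```python
import numpy as np

a = np.array([4, 5, 6, 7, 8, 9], dtype=np.int16)
b = np.arange(2, 20, dtype=np.int16)  # 2..19, the same values as b above
# Assumption: distinct values instead of zeros so selected elements differ.
c = np.arange(14 * 20 * 2, dtype=np.int16).reshape((14, 20, 2))

# Two array indices: a as a (6, 1) column and b as a length-18 vector.
# They broadcast to shape (6, 18); the trailing ':' keeps axis 2 whole.
d = c[a.reshape((a.size, 1)), b, :]
print(d.shape)  # (6, 18, 2)

# Element-wise meaning: d[i, j, k] == c[a[i], b[j], k]
print(d[0, 0, 1] == c[a[0], b[0], 1])  # True
```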
Answer:
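Although Numba supports the reshape function, the problem here is the indexing itself. In nopython mode Numba supports only a subset of NumPy advanced indexing: a single advanced index, which must be a one-dimensional array (it can be combined with basic indices and slices). The expression c[a.reshape((a.size, 1)), b, :] uses two array indices, one of them 2-D, so it is rejected. I removed the reshape and split the indexing into two steps, each using a single 1-D array index, which gives the same (6, 18, 2) result: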
import numpy as np
import numba as nb
a = np.array([4, 5, 6, 7, 8, 9], dtype=np.int16)
b = np.array([ 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=np.int16)
c = np.zeros((14, 20, 2), dtype=np.int16)
@nb.njit(fastmath=True)
def printNumbers(a, b, c):
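    # c[a]: one 1-D array index on axis 0 -> shape (6, 20, 2)
    # [:, b]: a slice plus one 1-D array index on axis 1 -> shape (6, 18, 2)
    d = c[a][:, b]
    print(d)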
printNumbers(a, b, c)
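Another option, sketched here as an assumption rather than part of the answer above, is to skip fancy indexing inside the jitted function and gather the elements with explicit loops, which Numba compiles directly. The function name gather_rows_cols is hypothetical, c is filled with distinct values instead of zeros so the comparison is meaningful, and the try/except fallback exists only so the sketch also runs where Numba is not installed:

```python
import numpy as np

try:
    import numba as nb
    njit = nb.njit
except ImportError:
    # Assumption: fall back to plain Python so the sketch runs without Numba.
    def njit(func):
        return func


@njit
def gather_rows_cols(a, b, c):
    # Hypothetical helper: d[i, j, k] = c[a[i], b[j], k] via explicit loops,
    # so no multi-array fancy indexing is needed inside the jitted code.
    # Output dtype is hard-coded to int16 to match the question's arrays.
    d = np.empty((a.size, b.size, c.shape[2]), dtype=np.int16)
    for i in range(a.size):
        for j in range(b.size):
            for k in range(c.shape[2]):
                d[i, j, k] = c[a[i], b[j], k]
    return d


a = np.array([4, 5, 6, 7, 8, 9], dtype=np.int16)
b = np.arange(2, 20, dtype=np.int16)  # 2..19, same values as b above
c = np.arange(14 * 20 * 2, dtype=np.int16).reshape((14, 20, 2))

d = gather_rows_cols(a, b, c)
print(d.shape)  # (6, 18, 2)
print(np.array_equal(d, c[a.reshape((a.size, 1)), b, :]))  # True
```

Besides sidestepping the indexing restriction, this version writes straight into one preallocated output, whereas c[a][:, b] first materializes the intermediate c[a] of shape (6, 20, 2).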