How to use np.random.randint with numba
The following code throws a very large error (full traceback: https://hastebin.com/kodixazewo.sql):
from numba import jit
import numpy as np
@jit(nopython=True)
def foo():
    a = np.random.randint(16, size=(3,3))
    return a
foo()
Answer:
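One option is to compile with nopython=False so that Numba can fall back to object mode, where np.random.randint accepts the size argument. See the Numba documentation on the nopython option for more detail.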
from numba import jit
import numpy as np
import warnings
warnings.filterwarnings("ignore")  # hides the NumbaWarning about the object-mode fallback (and every other warning); remove this line to read it
@jit(nopython=False)  # allow object mode: nopython mode supports only the first two arguments (low, high) of np.random.randint()
def foo():
    a = np.random.randint(16, size=(3,3))
    return a
foo()
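Object mode runs the call through the Python interpreter, so that part of the function typically gets little or no speedup from Numba. A minimal sketch of an alternative, assuming the random numbers can be drawn before the compiled work begins: generate them with regular NumPy, where size is supported, and pass the array into a nopython function. The function name process is hypothetical, and the try/except fallback is an assumption added only so the sketch also runs without Numba installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: let the sketch run as plain Python without Numba
    def njit(func):
        return func


@njit
def process(a):
    # Hypothetical compiled work on the pre-generated array: count values >= 8.
    count = 0
    for idx in np.ndindex(a.shape):
        if a[idx] >= 8:
            count += 1
    return count


# Plain NumPy call outside Numba, so the size argument is fully supported.
a = np.random.randint(16, size=(3, 3))
print(a)
print(process(a))
```

Only the random draw happens outside the compiled code; everything inside process stays in nopython mode.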
Answer:
You can use np.ndindex to loop over the desired output shape and call np.random.randint with scalar arguments, which nopython mode supports, for each element individually.
Make sure the output dtype is large enough to hold the range of integers returned by the randint call.
from numba import njit
import numpy as np
@njit
def foo(size=(3,3)):
    out = np.empty(size, dtype=np.uint16)
    for idx in np.ndindex(size):
        out[idx] = np.random.randint(16)
    return out
This works for arbitrary shapes:
foo(size=(2,2,2))
returns, for example (the values are random):
array([[[ 8,  7],
        [15,  2]],

       [[ 4, 13],
        [ 5, 11]]], dtype=uint16)
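The same loop can take the bounds as parameters. A sketch under a few assumptions: randint_array and seed are hypothetical helper names, np.int64 is used as a dtype wide enough for any int64 bounds, and seeding is done inside a jitted function because Numba keeps its own random state (calling np.random.seed() from ordinary Python code seeds NumPy's generator, not Numba's). The try/except fallback is again only there so the sketch runs without Numba.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: run as plain Python if Numba is not installed
    def njit(func):
        return func


@njit
def seed(value):
    # Inside a jitted function this seeds Numba's own generator,
    # which is separate from NumPy's global one.
    np.random.seed(value)


@njit
def randint_array(low, high, size):
    # Hypothetical helper: fill an array of any shape with integers in [low, high).
    out = np.empty(size, dtype=np.int64)
    for idx in np.ndindex(size):
        out[idx] = np.random.randint(low, high)
    return out


seed(42)
a = randint_array(5, 10, (2, 3))
seed(42)
b = randint_array(5, 10, (2, 3))
print(a)
```

Because seed(42) is called before each draw, a and b come out identical. A narrower dtype such as np.uint16 saves memory and is fine as long as 0 <= low and high <= 65536.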