numba.core.errors.TypingError: while using np.random.randint()

Time:09-21

How can I use np.random.randint with Numba? The code below throws a very large error (full traceback: https://hastebin.com/kodixazewo.sql):

from numba import jit
import numpy as np
@jit(nopython=True)
def foo():
    a = np.random.randint(16, size=(3,3))
    return a
foo()

Answer:

See the Numba documentation for more detail on the nopython option.

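from numba import jit
from numba.core.errors import NumbaWarning
import numpy as np
import warnings

# Hide only Numba's warnings (e.g. about falling back to object mode); remove this line to read them
warnings.filterwarnings("ignore", category=NumbaWarning)

# nopython mode only supports randint's first two arguments (low, high), so size= fails typing.
# nopython=False permits object mode, where the call is run by the Python interpreter
# instead of being compiled, so expect little or no speedup for this function.
@jit(nopython=False)
def foo():
    a = np.random.randint(16, size=(3,3))
    return a
foo()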
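Object mode means the randint call goes through the Python interpreter anyway. If the random values can be drawn before the compiled code runs, an alternative that keeps the function itself in nopython mode is to draw them with regular NumPy, where size= is fully supported, and pass the array in. This is a minimal sketch: row_sums is a hypothetical stand-in for whatever compiled work follows, and the import fallback is an assumption added only so the sketch still runs where Numba isn't installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: fall back to plain Python so the sketch runs without Numba
    def njit(func):
        return func


@njit
def row_sums(a):
    # Hypothetical compiled work on the pre-drawn array: sum each row with explicit loops
    out = np.zeros(a.shape[0], dtype=np.int64)
    for i in range(a.shape[0]):
        for j in range(a.shape[1]):
            out[i] += a[i, j]
    return out


# Draw the random array in ordinary NumPy code, outside the compiled function
a = np.random.randint(16, size=(3, 3))
sums = row_sums(a)
```

The trade-off is that the random draws no longer happen inside the compiled function, so this only fits when the shape is known before the call.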

Answer:

You can use np.ndindex to loop over your desired output size and call np.random.randint for each element individually.

Make sure the output dtype is large enough to hold every integer the randint call can return (here 0 to 15, so uint16 is more than enough).

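from numba import njit
import numpy as np

@njit
def foo(size=(3,3)):
    # Preallocate the output; uint16 easily holds the values 0..15
    out = np.empty(size, dtype=np.uint16)

    # np.ndindex yields every index tuple of the array, whatever its number of dimensions
    for idx in np.ndindex(size):
        # randint without size= compiles in nopython mode, so draw one value at a time
        out[idx] = np.random.randint(16)

    return out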

This works for any shape, for example:

foo(size=(2,2,2))

Results in:

array([[[ 8,  7],
        [15,  2]],

       [[ 4, 13],
        [ 5, 11]]], dtype=uint16)