How to set numba signature with nested lists?

Time:10-25

I'm trying to return a nested list, but I'm running into a type conversion error. Below is a small piece of code that reproduces the error.

from numba import njit, prange

@njit("ListType(ListType(ListType(int32)))(int32, int32)", fastmath = True, parallel = True, cache = True)
def test(x, y):
    a = []
    for i in prange(10):
        b = []
        for j in range(4):
            c = []
            for k in range(5):
                c.append(k)
            b.append(c)
        a.append(b)
    return a

Error: (screenshot of the error message not included)

Answer:

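I try to avoid using empty lists with Numba, mainly because an empty list has no elements from which Numba could infer an element type, so it cannot be typed. Check out `nb.typeof([])` (with `import numba as nb`), which raises an error instead of returning a type.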

I am not sure whether your output can be preallocated, but if it can, consider using NumPy arrays instead; they should also be much faster than nested lists. Here is an attempt:

from numba import njit, prange, int32
import numpy as np

@njit(int32[:, :, :](int32, int32), fastmath=True, parallel=True, cache=True)
def test(x, y):
    # Preallocate the whole (10, x, y) result instead of growing lists
    out = np.zeros((10, x, y), dtype=np.int32)
    for i in prange(10):  # each iteration writes only its own slice out[i]
        for j in range(x):
            for k in range(y):
                out[i, j, k] = k
    return out

That said, you might indeed need lists for your application, in which case this answer might not be of much use.

Answer:

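This worked for me. It replaces the built-in Python lists with `numba.typed.List` and drops the explicit signature, so Numba infers the nested list type on its own.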


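from numba import njit
from numba.typed import List

# parallel=True and prange are dropped here: appending to the shared list `a`
# from parallel iterations is a race condition.
@njit(fastmath=True, cache=True)
def test(x, y):
    a = List()  # typed list; item type is inferred from the first append
    for i in range(10):
        b = List()
        for j in range(4):
            c = List()
            for k in range(5):
                c.append(k)
            b.append(c)
        a.append(b)
    return a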