Why is my Numba JIT function recognized as an array?

Time:04-07

So I am trying to speed up some code, and while this works for most functions, a few do not. For those, I specify the function's signature and it fails. If I write only nb.njit, it works, but there is no speed-up at all; there is even a slight slowdown. When specifying the signature as in the code below, I get the following error:

TypeError: The decorated object is not a function (got type <class 'numba.core.types.npytypes.Array'>).


@nb.njit(nb.int32[:],nb.int32[:],nb.int32[:],nb.int32[:,:](nb.int8,nb.int32[:,:],nb.int32[:,:],nb.int32[:,:],nb.int64,nb.float64[:,:]))
def IdentifyVertsAndCells(i,cell_verts,vert_cells,vert_neighs,T1_edge,l):
    #Find vertices undergoing T1
    if T1_edge == len(l) - 1:
        T1_verts = cell_verts[i,[T1_edge,0]]
    else:
        T1_verts = cell_verts[i,[T1_edge,T1_edge+1]]
    
    #Identify the four cells that are affected by the transition
    dummy = np.concatenate((vert_cells[T1_verts[0]],vert_cells[T1_verts[1]]))
    T1cells = np.unique(dummy)

    #Identify cells that are neighbours prior to transition (that won't be afterwards)
    old_neigh = np.intersect1d(vert_cells[T1_verts[0]],vert_cells[T1_verts[1]])
    
    #Identify cells that will be neighbours after transition
    notneigh1 = T1cells[T1cells != old_neigh[0]]
    notneigh2 = T1cells[T1cells != old_neigh[1]]
    new_neigh = np.intersect1d(notneigh1,notneigh2)
    old_vert_neighs = vert_neighs[T1_verts,:]
    return T1_verts, old_neigh, new_neigh, old_vert_neighs

I checked the sizes and data types of my input arrays and scalars and am sure I did not make a mistake there. I should add that for the int8 argument, I had to convert an int to an int8 using j = np.asarray([i],dtype='int8')[0], because I didn't find a numba type for int, but I did for int8. The input i in my function corresponds to that j and is indeed of type int8. When I use inspect.isfunction on my function, it does recognize it as a function.

Here is the code calling the above function:

def UpdateTopology(points,verts,vert_neighs,vert_cells,cell_verts,l,x_max,y_max,T1_thresh,N):
    for i in range(N):
        #Determine how many vertices are in cell i
        vert_inds = cell_verts[i,:] < 2*N
        j = np.array([i]).astype('int8')[0]
        #If cell i only has three sides don't perform a T1 on it (you can't have a cell with 2 sides) 
        if np.count_nonzero(vert_inds) == 3:
            continue 
        
        #Find UP TO DATE vertex coords of cell i (can't use cell_vert_coords as
        #vertices will have changed due to previous T1 transitions)
        vert_inds = cell_verts[i,:] < 2*N
        cell_i = verts[cell_verts[i,vert_inds],:]
        
        #Shift vertex coords to account for periodicity
        rel_dists = cell_i - points[i,:]
        cell_i = ShiftCoords(cell_i,rel_dists,x_max,y_max,4)
        
        #Calculate the lengths, l, of each cell edge
        shifted_verts = np.roll(cell_i,-1,axis=0)
        l = shifted_verts - cell_i
        l_mag = np.linalg.norm(l,axis=1)
        #Determine if any edges are below threshold length
        to_change = np.nonzero(l_mag < T1_thresh)
        #print('l = ',l)
        #print('T1_thresh = ', T1_thresh)
        if len(to_change[0]) == 0:
            continue
        else:
            T1_edge = to_change[0][0]
            #Identify vertices and cells affected, also return vert_neighs of old neighbours (to be used when updating vert_neighs later on)
            T1_verts, old_neigh, new_neigh, old_vert_neighs = T1f.IdentifyVertsAndCells(i,cell_verts,vert_cells,vert_neighs,T1_edge,l)
            #Update vertex coordinates
            verts = T1f.UpdateVertCoords(j,verts,points,cell_i,old_neigh,T1_verts,T1_thresh,l,T1_edge,x_max,y_max)    
            #Update vert_cells
            vert_cells = T1f.UpdateVertCells(verts,points,vert_cells,T1_verts,old_neigh,new_neigh)
            #Update cell_verts 
            cell_verts = T1f.UpdateCellVerts(verts,points,cell_verts,T1_verts,old_neigh,new_neigh,N)
            #Update vert_neighs
            vert_neighs = T1f.UpdateVertNeighs(vert_neighs,points,cell_verts,T1_verts,old_neigh,new_neigh,old_vert_neighs,N)

    return verts, vert_neighs, vert_cells, cell_verts

Answer:

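The error you're getting is caused by the incorrectly defined signature you're passing to the njit decorator. Whenever your jitted function returns several values, you have to declare the return type as a single tuple type: either a homogeneous tuple (if all the returned values have the same type) or a heterogeneous tuple (if they differ) (see this answer). Because the four return types are instead passed as separate positional arguments, njit ends up treating the first one, nb.int32[:], as the object to decorate. That is why the message complains about an Array type rather than about your function, and why inspect.isfunction on your function still returns True.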

Regarding the speed-up: you won't get any with this code sample; rather, you'll get a slowdown. The two main reasons I can see are:

  1. You are relying only on NumPy's standard functions, which are already highly optimized. As a rule of thumb, numba works best at speeding up explicit loops. If your code is already vectorized, you probably won't get any speed improvement by njitting your function.
  2. You're using a lot of fancy indexing, e.g. cell_verts[i,[T1_edge,T1_edge+1]], which produces array copies that require memory allocation, a task numba is not particularly good at (see this answer).