So I am trying to speed up some code, and while it works for most functions, a few do not work. When I specify the function's signature, it fails. If I write only `nb.njit`, it works, but there is no speedup at all; there is even a slight slowdown. When specifying the signature as in the code I posted below, I get the following error:
TypeError: The decorated object is not a function (got type <class 'numba.core.types.npytypes.Array'>).
```python
import numba as nb
import numpy as np

@nb.njit(nb.int32[:],nb.int32[:],nb.int32[:],nb.int32[:,:](nb.int8,nb.int32[:,:],nb.int32[:,:],nb.int32[:,:],nb.int64,nb.float64[:,:]))
def IdentifyVertsAndCells(i,cell_verts,vert_cells,vert_neighs,T1_edge,l):
    #Find vertices undergoing T1
    if T1_edge == len(l) - 1:
        T1_verts = cell_verts[i,[T1_edge,0]]
    else:
        T1_verts = cell_verts[i,[T1_edge,T1_edge+1]]
    #Identify the four cells affected by the transition
    dummy = np.concatenate((vert_cells[T1_verts[0]],vert_cells[T1_verts[1]]))
    T1cells = np.unique(dummy)
    #Identify cells that are neighbours prior to transition (that won't be afterwards)
    old_neigh = np.intersect1d(vert_cells[T1_verts[0]],vert_cells[T1_verts[1]])
    #Identify cells that will be neighbours after transition
    notneigh1 = T1cells[T1cells != old_neigh[0]]
    notneigh2 = T1cells[T1cells != old_neigh[1]]
    new_neigh = np.intersect1d(notneigh1,notneigh2)
    old_vert_neighs = vert_neighs[T1_verts,:]
    return T1_verts, old_neigh, new_neigh, old_vert_neighs
```
I checked the sizes and data types of my input arrays and numbers and am sure I did not make a mistake there. I should add that, for the int8 argument, I had to convert an int to an int8 using `j = np.asarray([i],dtype='int8')[0]`, because I didn't find a numba type for int, but I did for int8. The input number `i` in my code corresponds to that `j` and is indeed of type int8. When I run `inspect.isfunction` on my function on its own, it recognizes it as a function.
Here is the code calling the above function:
```python
def UpdateTopology(points,verts,vert_neighs,vert_cells,cell_verts,l,x_max,y_max,T1_thresh,N):
    for i in range(N):
        #Determine how many vertices are in cell i
        vert_inds = cell_verts[i,:] < 2*N
        j = np.array([i]).astype('int8')[0]
        #If cell i only has three sides don't perform a T1 on it (you can't have a cell with 2 sides)
        if(len(vert_inds) == 3):
            continue
        #Find UP TO DATE vertex coords of cell i (can't use cell_vert_coords as
        #vertices will have changed due to previous T1 transitions)
        vert_inds = cell_verts[i,:] < 2*N
        cell_i = verts[cell_verts[i,vert_inds],:]
        #Shift vertex coords to account for periodicity
        rel_dists = cell_i - points[i,:]
        cell_i = ShiftCoords(cell_i,rel_dists,x_max,y_max,4)
        #Calculate the lengths, l, of each cell edge
        shifted_verts = np.roll(cell_i,-1,axis=0)
        l = shifted_verts - cell_i
        l_mag = np.linalg.norm(l,axis=1)
        #Determine if any edges are below threshold length
        to_change = np.nonzero(l_mag < T1_thresh)
        #print('l = ',l)
        #print('T1_thresh = ', T1_thresh)
        if len(to_change[0]) == 0:
            continue
        else:
            T1_edge = to_change[0][0]
            #Identify vertices and cells affected, also return vert_neighs of old neighbours (to be used when updating vert_neighs later on)
            T1_verts, old_neigh, new_neigh, old_vert_neighs = T1f.IdentifyVertsAndCells(i,cell_verts,vert_cells,vert_neighs,T1_edge,l)
            #Update vertex coordinates
            verts = T1f.UpdateVertCoords(j,verts,points,cell_i,old_neigh,T1_verts,T1_thresh,l,T1_edge,x_max,y_max)
            #Update vert_cells
            vert_cells = T1f.UpdateVertCells(verts,points,vert_cells,T1_verts,old_neigh,new_neigh)
            #Update cell_verts
            cell_verts = T1f.UpdateCellVerts(verts,points,cell_verts,T1_verts,old_neigh,new_neigh,N)
            #Update vert_neighs
            vert_neighs = T1f.UpdateVertNeighs(vert_neighs,points,cell_verts,T1_verts,old_neigh,new_neigh,old_vert_neighs,N)
    return verts, vert_neighs, vert_cells, cell_verts
```
Answer:
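The error you're getting comes from the wrongly defined signature you're passing to the `njit` decorator. Because the return types are separated by commas, `njit` receives `nb.int32[:]` (an array type, not a signature) as its first argument and tries to treat it as the function to compile, hence "The decorated object is not a function (got type ...Array)". Whenever your jitted function returns several values, you have to declare the return type as a single tuple type: either a homogeneous tuple (if all the returned values have the same type) or a heterogeneous tuple (if they differ) (see this answer).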
Regarding the speedup, you won't get any with this code sample; rather, you'll get a slowdown. The two main reasons I can see are the following:
- You are relying only on NumPy's standard functions, which are already highly optimized. As a rule of thumb, numba works well at speeding up loops. If your code is already vectorized, you probably won't get any speed improvement by `njit`ting your function.
- You're using a lot of fancy indexing, e.g. `cell_verts[i,[T1_edge,T1_edge+1]]`, which produces array copies that require memory allocation, a task numba is not particularly good at (see this answer).