I am working on a spatial search case for spheres in which I want to find connected spheres. To this end, I search around each sphere for spheres whose centers lie within a distance of (the maximum sphere diameter) of that sphere's center. At first, I tried to use Scipy-related methods to do so, but the Scipy method takes longer than the equivalent Numpy method. For Scipy, I first determined the number of K-nearest spheres and then found them with cKDTree.query, which led to more time consumption. However, it is slower than the Numpy method even when the first step is skipped in favor of a constant value (skipping the first step is not acceptable in this case). This is contrary to my expectations about Scipy's spatial-search speed. So, I tried using some list loops instead of some of the Numpy lines to speed things up with Numba's prange. Numba runs the code a little faster, but I believe this code can be optimized for better performance, perhaps by vectorization, by using other alternative Numpy modules, or by using Numba in another way. I have iterated over all spheres to prevent probable memory leaks and the like, since the number of spheres is high.
import numpy as np
import numba as nb
from scipy.spatial import cKDTree, distance
radii = np.load('a.npy')  # or np.loadtxt('radii_large.csv'); shape: (n_spheres,)
poss = np.load('b.npy')   # or np.loadtxt('pos_large.csv', delimiter=','); shape: (n_spheres, 3)
rad_max = np.amax(radii)
dia_max = 2 * rad_max
# @nb.jit('float64[:,::1](float64[:,::1], float64[::1])', forceobj=True, parallel=True)
def ends_gap(poss, dia_max):
    particle_corsp_overlaps = np.array([], dtype=np.float64)
    ends_ind = np.empty([1, 2], dtype=np.int64)

    """ using list looping """
    # particle_corsp_overlaps = []
    # ends_ind = []

    # for particle_idx in nb.prange(len(poss)):  # by list looping
    for particle_idx in range(len(poss)):
        unshared_idx = np.delete(np.arange(len(poss)), particle_idx)  # <--- relatively high time consumer
        poss_without = poss[unshared_idx]

        """ # SCIPY method ---------------------------------------------------------------------------------------------
        nears_i_ind = cKDTree(poss_without).query_ball_point(poss[particle_idx], r=dia_max, return_sorted=True)  # <--- high time consumer
        if len(nears_i_ind) > 0:
            dist_i, dist_i_ind = cKDTree(poss_without[nears_i_ind]).query(poss[particle_idx], k=len(nears_i_ind))  # <--- high time consumer
            if not isinstance(dist_i, float):
                dist_i[dist_i_ind] = dist_i.copy()
        """ # NUMPY method --------------------------------------------------------------------------------------------
        lx_limit_idx = poss_without[:, 0] <= poss[particle_idx][0] + dia_max
        ux_limit_idx = poss_without[:, 0] >= poss[particle_idx][0] - dia_max
        ly_limit_idx = poss_without[:, 1] <= poss[particle_idx][1] + dia_max
        uy_limit_idx = poss_without[:, 1] >= poss[particle_idx][1] - dia_max
        lz_limit_idx = poss_without[:, 2] <= poss[particle_idx][2] + dia_max
        uz_limit_idx = poss_without[:, 2] >= poss[particle_idx][2] - dia_max
        nears_i_ind = np.where(lx_limit_idx & ux_limit_idx & ly_limit_idx & uy_limit_idx & lz_limit_idx & uz_limit_idx)[0]
        if len(nears_i_ind) > 0:
            dist_i = distance.cdist(poss_without[nears_i_ind], poss[particle_idx][None, :]).squeeze()  # <--- relatively high time consumer
        # """ # -------------------------------------------------------------------------------------------------------

        contact_check = dist_i - (radii[unshared_idx][nears_i_ind] + radii[particle_idx])
        connected = contact_check[contact_check <= 0]
        particle_corsp_overlaps = np.concatenate((particle_corsp_overlaps, connected))

        """ using list looping """
        # if len(connected) > 0:
        #     for value_ in connected:
        #         particle_corsp_overlaps.append(value_)

        contacts_ind = np.where([contact_check <= 0])[1]
        contacts_sec_ind = np.array(nears_i_ind)[contacts_ind]
        sphere_olps_ind = np.where((poss[:, None] == poss_without[contacts_sec_ind][None, :]).all(axis=2))[0]  # <--- high time consumer
        ends_ind_mod_temp = np.array([np.repeat(particle_idx, len(sphere_olps_ind)), sphere_olps_ind], dtype=np.int64).T
        if particle_idx > 0:
            ends_ind = np.concatenate((ends_ind, ends_ind_mod_temp))
        else:
            ends_ind[0, 0], ends_ind[0, 1] = ends_ind_mod_temp[0, 0], ends_ind_mod_temp[0, 1]

        """ using list looping """
        # for contacted_idx in sphere_olps_ind:
        #     ends_ind.append([particle_idx, contacted_idx])

    # ends_ind_org = np.array(ends_ind)  # using lists
    ends_ind_org = ends_ind
    ends_ind, ends_ind_idx = np.unique(np.sort(ends_ind_org), axis=0, return_index=True)  # <--- relatively high time consumer
    gap = np.array(particle_corsp_overlaps)[ends_ind_idx]
    return gap, ends_ind, ends_ind_idx, ends_ind_org
In one of my tests on 23,000 spheres, the Scipy, Numpy, and Numba-aided methods finished the loop in about 400, 200, and 180 seconds respectively, using Colab TPU; for 500,000 spheres it takes 3.5 hours. These execution times are not at all satisfying for my project, where the number of spheres may be up to 1,000,000 in a medium data volume. I will call this code many times in my main code and am seeking ways to run it in milliseconds (as fast as it can possibly go). Is it possible? I would appreciate it if anyone could speed up the code as needed.
Notes:
- This code must be executable with Python 3.7+, on CPU and GPU.
- This code must be applicable for a data size of at least 300,000 spheres.
- Answers using numpy, scipy, and similar equivalent modules in place of my hand-written ones, making my code significantly faster, will be upvoted.
I would appreciate any recommendations or explanations about:
- Which method could be faster for this problem?
- Why is Scipy not faster than the other methods in this case, and where could it help with this problem?
- Choosing between iterator-based methods and matrix-form methods is confusing for me. Iterating methods use less memory and can be tuned with Numba and the like, but I think they are not comparable to matrix methods (which are limited by memory) such as Numpy for huge numbers of spheres. For this case, perhaps I could remove the iteration with pure Numpy, but I strongly suspect the huge matrix operations cannot be handled due to memory limits.
Prepared sample test data:
Poss data: 23000, 500000
Radii data: 23000, 500000
Line-by-line speed test logs: time consumption of the Scipy and Numpy methods for the two test cases.
CodePudding user response:
Step 1: better algorithm
First of all, building a k-d tree runs in O(n log n) time and doing a query runs in O(log n) time, where n is the number of points. So using a k-d tree seems a good idea at first glance. However, your code builds a k-d tree for each point, resulting in O(n² log n) time overall. This is why the Scipy solution is slower than the others. The thing is that Scipy does not provide a way to update a k-d tree, and it turns out that efficiently updating a k-d tree appears not to be possible. Fortunately, this is not a problem in your case: you can build one k-d tree with all the points once, and then discard the current point from the result of each query.
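To illustrate the difference, here is a minimal sketch (reusing poss and dia_max from the question) contrasting the per-point rebuild with a single shared tree:
# Rebuilding the tree for every point: n builds of O(n log n) each -> O(n^2 log n)
for i in range(len(poss)):
    tree = cKDTree(np.delete(poss, i, axis=0))
    tree.query_ball_point(poss[i], r=dia_max)

# Building the tree once and discarding the query point itself: one O(n log n) build + n queries
tree = cKDTree(poss)
for i in range(len(poss)):
    near = np.array(tree.query_ball_point(poss[i], r=dia_max))
    near = near[near != i]  # remove the current point from its own neighbourhood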
Moreover, the computation of sphere_olps_ind runs in O(n² m) time, where n is the total number of points and m is the average number of neighbours (i.e. the closest points retrieved from the k-d tree query). Assuming there are no duplicate points, it turns out that sphere_olps_ind is simply equal to np.sort(contacts_sec_ind). The latter runs in O(m log m), which is drastically better.
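As a small sanity check of this equivalence, here is a toy example (not part of the original code) showing that row-matching against unique points yields exactly the sorted indices:
# Toy demonstration: with unique points, matching rows of the position array
# against the contacted points returns the same indices as a simple sort.
poss_demo = np.array([[0., 0., 0.], [1., 0., 0.], [0., 2., 0.], [3., 0., 0.]])
contacts_sec_ind = np.array([2, 0])  # indices of contacted neighbours into poss_demo
via_match = np.where((poss_demo[:, None] == poss_demo[contacts_sec_ind][None, :]).all(axis=2))[0]
assert np.array_equal(via_match, np.sort(contacts_sec_ind))  # both give [0, 2]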
Additionally, using np.concatenate in a loop to append values to a Numpy array is slow because it creates a new, bigger array at each iteration. Using a list was a good idea, but appending Numpy arrays directly to a list and then calling np.concatenate once at the end is much faster.
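The difference between the two patterns, sketched with a placeholder chunks standing in for any sequence of per-iteration result arrays:
# Quadratic: every np.concatenate copies all the data accumulated so far
out = np.array([], dtype=np.float64)
for chunk in chunks:
    out = np.concatenate((out, chunk))

# Linear: collect the pieces in a list and concatenate once at the end
parts = []
for chunk in chunks:
    parts.append(chunk)
out = np.concatenate(parts)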
Here is the resulting code:
def ends_gap(poss, dia_max):
    particle_corsp_overlaps = []
    ends_ind = [np.empty([1, 2], dtype=np.int64)]

    kdtree = cKDTree(poss)

    for particle_idx in range(len(poss)):
        # Find the nearest points including the current one and
        # then remove the current point from the output.
        # The distances can be computed directly without a new query.
        cur_point = poss[particle_idx]
        nears_i_ind = np.array(kdtree.query_ball_point(cur_point, r=dia_max), dtype=np.int64)
        assert len(nears_i_ind) > 0

        if len(nears_i_ind) <= 1:
            continue

        nears_i_ind = nears_i_ind[nears_i_ind != particle_idx]
        dist_i = distance.cdist(poss[nears_i_ind], cur_point[None, :]).squeeze()

        contact_check = dist_i - (radii[nears_i_ind] + radii[particle_idx])
        connected = contact_check[contact_check <= 0]
        particle_corsp_overlaps.append(connected)

        contacts_ind = np.where([contact_check <= 0])[1]
        contacts_sec_ind = nears_i_ind[contacts_ind]
        sphere_olps_ind = np.sort(contacts_sec_ind)

        ends_ind_mod_temp = np.array([np.repeat(particle_idx, len(sphere_olps_ind)), sphere_olps_ind], dtype=np.int64).T
        if particle_idx > 0:
            ends_ind.append(ends_ind_mod_temp)
        else:
            ends_ind[0][:] = ends_ind_mod_temp[0, 0], ends_ind_mod_temp[0, 1]

    ends_ind_org = np.concatenate(ends_ind)
    ends_ind, ends_ind_idx = np.unique(np.sort(ends_ind_org), axis=0, return_index=True)  # <--- relatively high time consumer
    gap = np.concatenate(particle_corsp_overlaps)[ends_ind_idx]
    return gap, ends_ind, ends_ind_idx, ends_ind_org
Step 2: optimization
First of all, the query_ball_point call can be done for all the points at once, in parallel, by passing poss directly to the Scipy method and specifying the parameter workers=-1. However, note that this requires more memory.
Moreover, Numba can be used to significantly speed up the computation. The parts that can mainly be improved are the computation of the distances and the creation of many unnecessary temporary arrays, as well as using direct Numpy array indexing instead of list appends (since a bound on the size of the output array is known after the query_ball_point call).
Here is a simple example of optimized code using Numba:
@nb.jit('(float64[:, ::1], int64[::1], int64[::1], float64)')
def compute(poss, all_neighbours, all_neighbours_sizes, dia_max):
    particle_corsp_overlaps = []
    ends_ind_lst = [np.empty((1, 2), dtype=np.int64)]
    an_offset = 0

    for particle_idx in range(len(poss)):
        cur_point = poss[particle_idx]
        cur_len = all_neighbours_sizes[particle_idx]
        nears_i_ind = all_neighbours[an_offset:an_offset + cur_len]
        an_offset += cur_len
        assert len(nears_i_ind) > 0

        if len(nears_i_ind) <= 1:
            continue

        nears_i_ind = nears_i_ind[nears_i_ind != particle_idx]
        dist_i = np.empty(len(nears_i_ind), dtype=np.float64)

        # Compute the distances
        x1, y1, z1 = poss[particle_idx]
        for i in range(len(nears_i_ind)):
            x2, y2, z2 = poss[nears_i_ind[i]]
            dist_i[i] = np.sqrt((x2 - x1)**2 + (y2 - y1)**2 + (z2 - z1)**2)

        contact_check = dist_i - (radii[nears_i_ind] + radii[particle_idx])
        connected = contact_check[contact_check <= 0]
        particle_corsp_overlaps.append(connected)

        contacts_ind = np.where(contact_check <= 0)
        contacts_sec_ind = nears_i_ind[contacts_ind]
        sphere_olps_ind = np.sort(contacts_sec_ind)

        ends_ind_mod_temp = np.empty((len(sphere_olps_ind), 2), dtype=np.int64)
        for i in range(len(sphere_olps_ind)):
            ends_ind_mod_temp[i, 0] = particle_idx
            ends_ind_mod_temp[i, 1] = sphere_olps_ind[i]

        if particle_idx > 0:
            ends_ind_lst.append(ends_ind_mod_temp)
        else:
            tmp = ends_ind_lst[0]
            tmp[:] = ends_ind_mod_temp[0, :]

    return particle_corsp_overlaps, ends_ind_lst

def ends_gap(poss, dia_max):
    kdtree = cKDTree(poss)

    tmp = kdtree.query_ball_point(poss, r=dia_max, workers=-1)
    all_neighbours = np.concatenate(tmp, dtype=np.int64)
    all_neighbours_sizes = np.array([len(e) for e in tmp], dtype=np.int64)

    particle_corsp_overlaps, ends_ind_lst = compute(poss, all_neighbours, all_neighbours_sizes, dia_max)
    ends_ind_org = np.concatenate(ends_ind_lst)
    ends_ind, ends_ind_idx = np.unique(np.sort(ends_ind_org), axis=0, return_index=True)
    gap = np.concatenate(particle_corsp_overlaps)[ends_ind_idx]
    return gap, ends_ind, ends_ind_idx, ends_ind_org
ends_gap(poss, dia_max)
Performance analysis
Here are the performance results on my 6-core machine (with an i5-9600KF processor) on the small dataset:
Initial code with Scipy: 259 s
Initial default code with Numpy: 112 s
Optimized algorithm: 1.37 s
Final optimized code: 0.22 s
Here are the performance results on the big dataset:
Initial code with Scipy: 100000 s (estimation)
Initial default code with Numpy: 6700 s (estimation)
Optimized algorithm: 6.36 s
Final optimized code: 1.28 s
Thus the Numba implementation with an efficient algorithm is up to ~5230 times faster than the initial Numpy implementation and ~78,000 times faster than the initial Scipy implementation.
The Numba code can be further optimized, but please note that the Numba compute call takes less than 25% of the time on my machine. The np.unique call is the most expensive part, and it is not easy to make faster. A significant part of the time is also spent in the Scipy-to-Numba data conversion, but this code is mandatory as long as Scipy is used. Thus, the code can be improved a bit (e.g. certainly 2x faster) with advanced Numba optimizations, but if you need much faster code, then you need to use a native language like C and a highly-optimized parallel k-d tree implementation. I expect very optimized native code to be an order of magnitude faster, but not much more. I doubt the big dataset can be computed in less than 10 ms on my machine, regardless of the implementation.
Notes
Note that gap differs between the provided functions (the other returned values are unchanged). However, the same discrepancy exists between the initial Scipy method and the initial Numpy one. It appears to come from the ordering of variables like nears_i_ind and dist_i, which is undefined by Scipy and changes the gap result in a non-trivial way (not just the order of gap). I am not sure this is a problem with the initial implementation. Because of this, it is much harder to compare the correctness of the different implementations.
forceobj=True should not be used in production, as the documentation states it is only useful for testing purposes.
CodePudding user response:
Have you tried FLANN?
This code doesn't solve your problem completely. It simply finds the 50 nearest neighbours of each point in your 500,000-point dataset:
import numpy as np
from pyflann import FLANN

p = np.loadtxt("pos_large.csv", delimiter=",")
flann = FLANN()
flann.build_index(pts=p)
idx, dist = flann.nn_index(qpts=p, num_neighbors=50)
The last line takes less than a second on my laptop, without any tuning or parallelization.
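If you go down this road, the k-NN output still has to be turned into contact pairs. Below is a rough, hypothetical sketch of that post-processing. It assumes radii is loaded as in the question, that 50 neighbours are enough to capture every contact, and that the returned distances are squared Euclidean (FLANN's usual convention); all of this should be verified before relying on it:
# Hypothetical post-processing of the FLANN output above (idx, dist).
# idx has shape (n, 50): neighbour indices, usually including the point itself.
# FLANN typically reports squared Euclidean distances - verify for your build.
d = np.sqrt(dist)                                   # only needed if dist is squared
rad_sums = radii[:, None] + radii[idx]              # sum of radii for each candidate pair
self_mask = idx != np.arange(len(radii))[:, None]   # drop each point's match with itself
contact = (d <= rad_sums) & self_mask
rows, cols = np.nonzero(contact)
ends_ind = np.stack((rows, idx[rows, cols]), axis=1)  # (particle, contacted neighbour) pairs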