I need to create a function nearest_neighbor(src, dst), which accepts two arrays of 2D points and, for every point in src, calculates the distance to and the index of its closest neighbor in dst.
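Written as a formula, with $s_i$ for the rows of src and $t_j$ for the rows of dst (notation introduced here, not in the question), the function should return, for each $i$:

```latex
d_i = \min_{j} \lVert s_i - t_j \rVert_2, \qquad
k_i = \arg\min_{j} \lVert s_i - t_j \rVert_2, \qquad
\lVert s_i - t_j \rVert_2 = \sqrt{(s_{i,x} - t_{j,x})^2 + (s_{i,y} - t_{j,y})^2}
```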
Example input:
```python
src = np.array([[1, 1], [2, 2], [3, 3], [4, 4], [9, 9]])
dst = np.array([[6, 7], [10, 10], [10, 20]])
```
Example output:
```python
(array([7.81024968, 6.40312424, 5.        , 3.60555128, 1.41421356]),
 array([0, 0, 0, 0, 1]))
```
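As a hand check of the expected output, here are the distances from the first src point (1, 1) and the last src point (9, 9) to each dst point. The minima are $\sqrt{61}$ at index 0 and $\sqrt{2}$ at index 1 (0-based), matching the first and last entries above:

```latex
\begin{aligned}
\lVert (1,1) - (6,7) \rVert   &= \sqrt{5^2 + 6^2}  = \sqrt{61}  \approx 7.8102 \\
\lVert (1,1) - (10,10) \rVert &= \sqrt{9^2 + 9^2}  = \sqrt{162} \approx 12.7279 \\
\lVert (1,1) - (10,20) \rVert &= \sqrt{9^2 + 19^2} = \sqrt{442} \approx 21.0238 \\[4pt]
\lVert (9,9) - (6,7) \rVert   &= \sqrt{3^2 + 2^2}  = \sqrt{13}  \approx 3.6056 \\
\lVert (9,9) - (10,10) \rVert &= \sqrt{1^2 + 1^2}  = \sqrt{2}   \approx 1.4142 \\
\lVert (9,9) - (10,20) \rVert &= \sqrt{1^2 + 11^2} = \sqrt{122} \approx 11.0454
\end{aligned}
```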
With sklearn you can do it like this:
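```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

def nearest_neighbor(src, dst):
    neigh = NearestNeighbors(n_neighbors=1)
    neigh.fit(dst)
    # kneighbors returns arrays of shape (len(src), 1); ravel flattens them
    distances, indices = neigh.kneighbors(src, return_distance=True)
    return distances.ravel(), indices.ravel()
```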
But I need to implement it using only NumPy. This is my version:
```python
def nearest_neighbor(src, dst):
    distances = []
    indices = []
    for dot in src:
        # distances from this src point to every dst point
        dists = np.linalg.norm(dst - dot, axis=1)
        dist = np.min(dists)
        idx = np.argmin(dists)
        distances.append(dist)
        indices.append(idx)
    return np.array(distances), np.array(indices)
```
But it is slow because of the Python loop. How can I make it faster?
Answer:
You should read up on NumPy broadcasting:
```python
# src[:, None] has shape (5, 1, 2); subtracting dst (3, 2) broadcasts to (5, 3, 2)
dist = np.square(src[:, None] - dst).sum(axis=-1) ** .5
idx = dist.argmin(axis=-1)
# array([0, 0, 0, 0, 1])
min_dist = dist[np.arange(len(dist)), idx]
```
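The same broadcasting idea can be wrapped into the requested function. This minimal sketch picks the argmin on squared distances and takes the square root only of the selected minima; that is valid because the square root is monotonic, so it does not change which dst point is closest. It assumes 2D point arrays as in the question:

```python
import numpy as np

def nearest_neighbor(src, dst):
    # (N, 1, 2) - (M, 2) broadcasts to (N, M, 2): every src point minus every dst point
    diff = src[:, None, :] - dst
    # squared Euclidean distances, shape (N, M)
    sq_dist = (diff ** 2).sum(axis=-1)
    # index of the closest dst point for each src point
    idx = sq_dist.argmin(axis=1)
    # sqrt only on the N selected minima instead of the full N x M matrix
    min_dist = np.sqrt(sq_dist[np.arange(len(src)), idx])
    return min_dist, idx
```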
Answer:
You can use scipy.spatial.distance.cdist:
```python
import numpy as np
from scipy.spatial.distance import cdist

# compute the (len(src), len(dst)) matrix of pairwise distances
dist = cdist(src, dst)
# index of the closest dst point for each src point
closest = dist.argmin(axis=1)
# array([0, 0, 0, 0, 1])
distance = dist[np.arange(src.shape[0]), closest]
# array([7.81024968, 6.40312424, 5.        , 3.60555128, 1.41421356])
```
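If SciPy is not an option, a full pairwise matrix like the one cdist returns can also be built in plain NumPy with a single matrix product, using the identity ||a - b||² = ||a||² + ||b||² - 2 a·b. This is a swapped-in technique, not taken from any of the answers here, and nearest_neighbor_matmul is a hypothetical name:

```python
import numpy as np

def nearest_neighbor_matmul(src, dst):
    src = np.asarray(src, dtype=float)
    dst = np.asarray(dst, dtype=float)
    # ||a||^2 + ||b||^2 - 2 a.b for all (src, dst) pairs, shape (N, M)
    sq = (src ** 2).sum(axis=1)[:, None] + (dst ** 2).sum(axis=1)[None, :] - 2.0 * src @ dst.T
    # floating-point cancellation can produce tiny negatives; clamp before sqrt
    np.maximum(sq, 0.0, out=sq)
    idx = sq.argmin(axis=1)
    return np.sqrt(sq[np.arange(len(src)), idx]), idx
```

This avoids the (N, M, 2) intermediate that the broadcasting answers create, but the subtraction can lose precision when the points lie far from the origin relative to their spacing, so the direct difference-based versions are safer in that case.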
Answer:
Using broadcasting, src[:, None] - dst subtracts every row of dst from every row of src:
```python
>>> import numpy as np
>>> def nearest_neighbor(src, dst):
...     dist = np.linalg.norm(src[:, None] - dst, axis=-1)
...     indices = dist.argmin(-1)
...     return dist[np.arange(len(dist)), indices], indices
...
>>> src = np.array([[1,1], [2,2],[3,3],[4,4],[9,9]])
>>> dst = np.array([[6,7],[10,10],[10,20]])
>>> nearest_neighbor(src, dst)
(array([7.81024968, 6.40312424, 5.        , 3.60555128, 1.41421356]),
 array([0, 0, 0, 0, 1], dtype=int64))
```
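The one-shot broadcast materializes an (N, M, 2) array, so memory grows with len(src) × len(dst): for 100,000 points on each side that intermediate alone is roughly 160 GB of float64. A minimal sketch of a memory-bounded variant, assuming it is acceptable to loop over blocks of src in Python (the function name and the chunk_size default are hypothetical choices):

```python
import numpy as np

def nearest_neighbor_chunked(src, dst, chunk_size=1024):
    # chunk_size is a hypothetical tuning knob: larger blocks mean fewer Python
    # iterations but a bigger (chunk_size, M, 2) temporary array
    n = len(src)
    min_dist = np.empty(n, dtype=float)
    idx = np.empty(n, dtype=np.intp)
    for start in range(0, n, chunk_size):
        block = src[start:start + chunk_size]
        # (B, 1, 2) - (M, 2) -> (B, M, 2), then norm over the last axis -> (B, M)
        d = np.linalg.norm(block[:, None] - dst, axis=-1)
        i = d.argmin(axis=-1)
        idx[start:start + len(block)] = i
        min_dist[start:start + len(block)] = d[np.arange(len(block)), i]
    return min_dist, idx
```

Each block is still fully vectorized, so the Python loop runs only about len(src) / chunk_size times instead of once per point as in the question's version.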