I'm trying to build a function that returns the index of the centroid with the shortest distance to a data point. However, I'm getting this error: `IndexError: arrays used as indices must be of integer (or boolean) type`.
```python
import numpy as np

b = np.ndarray((3, 2, 1, 4))
DEFAULT_CENTROIDS = np.array([[5.664705882352942, 3.0352941176470587, 3.3352941176470585, 1.0176470588235293],
                              [5.446153846153847, 3.2538461538461543, 2.9538461538461536, 0.8846153846153846],
                              [5.906666666666667, 2.933333333333333, 4.1000000000000005, 1.3866666666666667],
                              [5.992307692307692, 3.0230769230769234, 4.076923076923077, 1.3461538461538463],
                              [5.747619047619048, 3.0714285714285716, 3.6238095238095243, 1.1380952380952383],
                              [6.161538461538462, 3.030769230769231, 4.484615384615385, 1.5307692307692309],
                              [6.294117647058823, 2.9764705882352938, 4.494117647058823, 1.4],
                              [5.853846153846154, 3.215384615384615, 3.730769230769231, 1.2076923076923078],
                              [5.52857142857143, 3.142857142857143, 3.107142857142857, 1.007142857142857],
                              [5.828571428571429, 2.9357142857142855, 3.664285714285714, 1.1]])

def get_closest(data_point: np.ndarray, centroids: np.ndarray):
    """
    Takes a data_point and an np.ndarray of multiple centroids and returns the index of the centroid closest to data_point
    by computing the Euclidean distance for each centroid and picking the closest.
    """
    N = centroids.shape[0]
    dist = np.empty(N)
    for i in centroids:
        dist[i] = np.linalg.norm(centroids[i]-data_point)
    index_min = np.argmin(dist)
    return index_min # the index of the centroid closest to the datapoint

get_closest(b,DEFAULT_CENTROIDS)
```
```
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-52-7c41e5374a5e> in <module>
----> 1 get_closest(b,DEFAULT_CENTROIDS)

<ipython-input-51-8e1cd568d0df> in get_closest(data_point, centroids)
     11     dist = np.empty(N)
     12     for i in centroids:
---> 13         dist[i] = np.linalg.norm(centroids[i]-data_point)
     14     index_min = np.argmin(dist)
     15     return index_min # the index of the centroid closest to the datapoint

IndexError: arrays used as indices must be of integer (or boolean) type
```
I don't quite understand the error or why my code is wrong. Suggestions? Thanks in advance.
Answer:
In `for i in centroids:`, `i` is each element of `centroids` (here, each row, which is an array of floats), not an index into it. So `centroids[i]` tries to index with an array of floats, and NumPy rejects that with exactly this `IndexError`.
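A quick way to see this is a small check on a toy array. The 3×2 `centroids` below is a hypothetical stand-in for `DEFAULT_CENTROIDS`: iterating yields float rows, using one as an index raises the same `IndexError`, and `enumerate` supplies the integer positions instead.

```python
import numpy as np

# Hypothetical stand-in for DEFAULT_CENTROIDS: 3 centroids, 2 features each.
centroids = np.array([[1.0, 2.0],
                      [3.0, 4.0],
                      [5.0, 6.0]])

# Iterating over a 2-D array yields its rows, not row numbers.
rows = [i for i in centroids]
print(rows[0])        # [1. 2.]
print(rows[0].dtype)  # float64

# Using a float row as an index is what fails in the question.
raised = False
try:
    centroids[rows[0]]
except IndexError as exc:
    raised = True
    print("IndexError:", exc)

# enumerate supplies the integer index alongside each row.
indices = [i for i, _ in enumerate(centroids)]
print(indices)  # [0, 1, 2]
```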
Since you need both the index and the row, `enumerate` is probably your best bet:
```python
def get_closest(data_point: np.ndarray, centroids: np.ndarray):
    """
    Takes a data_point and an np.ndarray of multiple centroids and returns the index of the centroid closest to data_point
    by computing the Euclidean distance for each centroid and picking the closest.
    """
    N = centroids.shape[0]
    dist = np.empty(N)
    for i, c in enumerate(centroids):
        # i is the integer position, c is the centroid row itself
        dist[i] = np.linalg.norm(c - data_point)
    index_min = np.argmin(dist)
    return index_min # the index of the centroid closest to the datapoint
```
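If you want to drop the Python loop entirely, NumPy broadcasting can compute all distances in one call. This is a swapped-in vectorized variant, not part of the fix above, and `get_closest_vectorized` is a hypothetical name. It assumes `data_point` is a single 1-D point with one value per feature. Note that `b = np.ndarray((3, 2, 1, 4))` in the question is an uninitialized 4-D array rather than such a point; in the loop version, `c - data_point` would broadcast against it without error and produce a number that isn't a point-to-centroid distance.

```python
import numpy as np

def get_closest_vectorized(data_point: np.ndarray, centroids: np.ndarray) -> int:
    """Index of the centroid nearest to data_point (Euclidean distance).

    Assumes data_point has shape (n_features,) and centroids has shape
    (n_centroids, n_features).
    """
    # Broadcasting subtracts data_point from every centroid row at once;
    # axis=1 takes the norm of each row, giving one distance per centroid.
    dist = np.linalg.norm(centroids - data_point, axis=1)
    return int(np.argmin(dist))

# Hypothetical toy data: 3 centroids with 2 features each.
centroids = np.array([[0.0, 0.0], [10.0, 10.0], [5.0, 5.0]])
print(get_closest_vectorized(np.array([6.0, 4.0]), centroids))  # 2
```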
That said, if `centroids` is large or speed matters, you'd probably be better off using a `KDTree`:
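```python
from scipy.spatial import KDTree

centroid_KDTree = KDTree(centroids)  # build the tree once from the centroid rows
dist, index_min = centroid_KDTree.query(data_point)  # distance to, and index of, the nearest centroid
```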
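Here is a sketch of the `KDTree` route that you can run on its own, assuming SciPy is installed and using random, hypothetical data. Building the tree has a cost, so it mainly pays off when you query it many times; `query` also accepts a batch of points, and the result can be cross-checked against a brute-force NumPy computation.

```python
import numpy as np
from scipy.spatial import KDTree

rng = np.random.default_rng(0)
# Hypothetical data: 1000 centroids and 5 query points, 4 features each.
centroids = rng.random((1000, 4))
points = rng.random((5, 4))

tree = KDTree(centroids)  # build once, reuse for many queries

# A batch query returns one distance and one centroid index per point.
dist, idx = tree.query(points)

# A single 1-D point returns a scalar distance and a scalar index.
d0, i0 = tree.query(points[0])

# Brute-force check: all point-to-centroid distances, then argmin per point.
brute = np.argmin(
    np.linalg.norm(centroids[None, :, :] - points[:, None, :], axis=2), axis=1
)
print(np.array_equal(idx, brute))  # True
```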