I am working on my own implementation of the weighted kNN algorithm.
To simplify the logic, let's represent it as a predict method that takes three parameters:
indices - a matrix of the j nearest neighbors from the training sample for each object i (i = 1...n, n objects in total); entry [i, j] is the index of the j-th nearest training object. For example, for 4 objects and 3 neighbors:
import numpy as np

indices = np.asarray([[0, 3, 1],
                      [0, 3, 1],
                      [1, 2, 0],
                      [5, 4, 3]])
distances - a matrix of the distances from object i to each of its j nearest training neighbors, with the same shape as indices. For example, for 4 objects and 3 neighbors:
distances = np.asarray([[   4.12310563,    7.07106781,    7.54983444],
                        [   4.89897949,    6.70820393,    8.24621125],
                        [   0.        ,    1.73205081,    3.46410162],
                        [1094.09368886, 1102.55022561, 1109.62245832]])
labels - a vector with the true class label of each object in the training sample. For example:
labels = np.asarray([0, 0, 0, 1, 1, 2])
Thus, the function signature is:
def predict(indices, distances, labels):
    ...
    # unweighted kNN: a plain majority vote among the neighbors
    # return [np.bincount(x).argmax() for x in labels[indices]]
    return predictions
In the comment you can see the code that returns the predictions for the "non-weighted" kNN method, which does not use the distances. Can you please show how the predictions can be calculated using the distance matrix? I found the algorithm, but now I'm completely stumped because I don't know how to implement it with NumPy.
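For reference, on the sample data above the commented-out unweighted version gives the following (hand-checked):

neigh = labels[indices]  # [[0, 1, 0], [0, 1, 0], [0, 0, 0], [2, 1, 1]]
preds = [np.bincount(x).argmax() for x in neigh]  # -> [0, 0, 0, 1]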
Thank you!
Answer:
This should work:
# compute the inverses of the distances;
# suppress the division-by-zero warning and
# replace the resulting np.inf with a very large finite number
with np.errstate(divide='ignore'):
    dinv = np.nan_to_num(1 / distances)
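# e.g. the third sample row [0., 1.73205081, 3.46410162] becomes
# roughly [1.8e308, 0.577, 0.289]: the zero-distance neighbor gets
# an effectively infinite weight and dominates the vote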
# an array with the distinct class labels, in sorted order
distinct_labels = np.unique(labels)
# an array with labels of neighbors
neigh_labels = labels[indices]
# compute the weighted score for each potential label
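# shapes: (n, k, 1) == (m,) broadcasts to an (n, k, m) boolean mask;
# multiplying by the (n, k, 1) weights and summing over the neighbor
# axis k gives an (n, m) matrix of per-class weight totals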
weighted_scores = ((neigh_labels[:, :, np.newaxis] == distinct_labels)
                   * dinv[:, :, np.newaxis]).sum(axis=1)
# choose the label with the highest score
predictions = distinct_labels[weighted_scores.argmax(axis=1)]
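Running this on the sample arrays from the question gives (output checked by hand, so treat it as the expected result rather than a guarantee):

print(predictions)  # [0 0 0 1]

The last object ends up in class 1 because its two class-1 neighbors contribute a combined weight of about 0.00181, which outweighs the single class-2 neighbor's 0.00091; the zero-distance neighbor settles the third object on its own.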