Use NumPy to generate a kneighbors_graph like scikit-learn?

Time:06-10

I'm trying to brush up on my understanding of some basic ML methods. Can someone explain what is going on under the hood with kneighbors_graph? I would like to replicate this output using only NumPy.

from sklearn.neighbors import kneighbors_graph

X = [[0, 1], [3, 4], [7, 8]]
A = kneighbors_graph(X, 2, mode='distance', include_self=True)
A.toarray()

Output:

array([[0.        , 4.24264069, 0.        ],
       [4.24264069, 0.        , 0.        ],
       [0.        , 5.65685425, 0.        ]])

Answer:

The resulting matrix is the distance-weighted graph of the n_neighbors=2 nearest neighbours of each point in X. Because you set include_self=True, each point counts as one of its own neighbours, at a distance of zero. Note that non-neighbours also show up as zeros, so to tell a zero-distance neighbour apart from a non-neighbour you need to look at the connectivity graph (mode='connectivity') as well.
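To make the zero ambiguity concrete, here is a small sketch that takes the dense distance graph from the question and masks it with the matching connectivity graph. The connectivity matrix below is what I would expect mode='connectivity' to return for this X (worked out by hand from the distances, not copied from a run), and marking non-neighbours with NaN is just one possible convention.

```python
import numpy as np

# Dense distance graph copied from the question's output (include_self=True)
A_dist = np.array([[0.0, 4.24264069, 0.0],
                   [4.24264069, 0.0, 0.0],
                   [0.0, 5.65685425, 0.0]])

# Assumed connectivity graph for the same X (derived by hand, not from a run):
# 1 marks a neighbour (the point itself included), 0 a non-neighbour
A_conn = np.array([[1, 1, 0],
                   [1, 1, 0],
                   [0, 1, 1]])

# Self-distances on the diagonal and non-neighbours both appear as 0 here
print(np.argwhere(A_dist == 0))

# Keep distances only where there is an edge; mark non-neighbours as NaN
A_clear = np.where(A_conn == 1, A_dist, np.nan)
print(A_clear)
```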

Let's start with the first row, representing the first point, [0, 1]. Here's what the numbers in that row mean:

  • The first 0 is the distance to the nearest point, which is itself (because you specified include_self=True). If you specified mode='connectivity' this would be a 1, because it's a neighbour.
  • The second element, 4.24, is the Euclidean distance (the L2 norm of the difference, sqrt(3² + 3²)) to the next point in X, which is [3, 4]. You get this distance because metric='minkowski' and p=2 are the defaults; if you want a different distance, you can pass a different metric (or p) to kneighbors_graph. Again, if you specified mode='connectivity' this would also be a 1, because it's a neighbour.
  • The third element, another 0, is not really a distance; it's telling you that the third point, [7, 8], is not a neighbour when n_neighbors is 2. If you specified mode='connectivity' this would be a 0, because it's not a neighbour.
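The first-row walkthrough above can be reproduced in a few lines of NumPy. This is a minimal sketch assuming plain Euclidean distance (the minkowski/p=2 default) and include_self=True; it simply takes the two smallest distances with np.argsort and does not treat ties specially.

```python
import numpy as np

X = np.array([[0, 1], [3, 4], [7, 8]], dtype=float)
n_neighbors = 2

# Euclidean distance from the first point to every point, itself included
d0 = np.linalg.norm(X - X[0], axis=1)   # approx. [0, 4.243, 9.899]

# Indices of the n_neighbors closest points; the point itself comes first
nearest = np.argsort(d0)[:n_neighbors]

# mode='distance': keep the neighbour distances, leave everything else at 0
row_distance = np.zeros(len(X))
row_distance[nearest] = d0[nearest]

# mode='connectivity': 1 for every neighbour (including itself), 0 otherwise
row_connectivity = np.zeros(len(X))
row_connectivity[nearest] = 1.0

print(row_distance)
print(row_connectivity)
```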

You can compute the distances between all pairs of points in an array with scipy.spatial.distance.cdist(X, X). There's also scipy.spatial.KDTree for neighbour lookup. If you really want to go pure NumPy, np.linalg.norm combined with broadcasting gives you the pairwise distances, and np.argsort then picks out the nearest points in each row.
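Putting those pieces together, a pure-NumPy replica could look like the sketch below. knn_graph is a hypothetical helper name, not a NumPy or scikit-learn function. It only supports Euclidean distance, returns a dense array rather than a sparse matrix, and its defaults are meant to mirror kneighbors_graph's (mode='connectivity', include_self=False); tie-breaking between equidistant points may not match scikit-learn.

```python
import numpy as np

def knn_graph(X, n_neighbors, mode="connectivity", include_self=False):
    """Dense NumPy sketch of a k-nearest-neighbours graph (Euclidean only)."""
    X = np.asarray(X, dtype=float)
    n = X.shape[0]

    # Pairwise Euclidean distances via broadcasting: (n, 1, d) - (1, n, d)
    dist = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)

    # Optionally stop a point from being picked as its own neighbour
    search = dist.copy()
    if not include_self:
        np.fill_diagonal(search, np.inf)

    # Column indices of the n_neighbors smallest distances in each row;
    # a stable sort sends ties to the lower index (sklearn may differ)
    nbrs = np.argsort(search, axis=1, kind="stable")[:, :n_neighbors]
    rows = np.arange(n)[:, None]  # shape (n, 1), broadcasts against nbrs

    graph = np.zeros((n, n))
    if mode == "distance":
        graph[rows, nbrs] = dist[rows, nbrs]
    elif mode == "connectivity":
        graph[rows, nbrs] = 1.0
    else:
        raise ValueError("mode must be 'connectivity' or 'distance'")
    return graph

X = [[0, 1], [3, 4], [7, 8]]
print(knn_graph(X, 2, mode="distance", include_self=True))
print(knn_graph(X, 2, mode="connectivity", include_self=True))
```

The broadcasting step builds an n × n × d intermediate array, so this only makes sense for small inputs; for larger data, cdist or KDTree.query is the more practical route.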
