I have a matrix of the below format:
matrix = array([[-0.2436986 , -0.25583658, -0.16579486, ..., -0.04291612,
        -0.06026303,  0.08564489],
       [-0.08684622, -0.21300158, -0.04034272, ..., -0.01995692,
        -0.07747065,  0.06965207],
       [-0.34814256, -0.20597479,  0.06931241, ..., -0.1236965 ,
        -0.1300714 , -0.110122  ],
       ...,
       [-0.04154776, -0.07538085,  0.01860147, ..., -0.01494173,
        -0.08960884, -0.21338603],
       [-0.34039265, -0.24616522,  0.10838407, ...,  0.22280858,
        -0.03465452,  0.04178255],
       [-0.30251586, -0.23072125, -0.01975435, ...,  0.34529492,
        -0.03508861,  0.00699677]], dtype=float32)
Since I want to calculate the squared distance of each row to every other row, I am using the code below:
import numpy as np

def sq_dist(a, b):
    """
    Returns the squared distance between two vectors.
    Args:
      a (ndarray (n,)): vector with n features
      b (ndarray (n,)): vector with n features
    Returns:
      d (float): squared distance
    """
    d = np.sum(np.square(a - b))
    return d

dim = len(matrix)
dist = np.zeros((dim, dim))
for i in range(dim):
    for j in range(dim):
        dist[i, j] = sq_dist(matrix[i, :], matrix[j, :])
I am getting the correct result, but it already takes 17 minutes for just 5,000 rows. Since the full data has 100k rows (i.e. the output is a 100k x 100k matrix), the cluster job fails after 5 hours.
How can I do this efficiently for such a large matrix? I am using Python 3.8 and PySpark.
Output matrix should be like:
dist = array([[0. , 0.57371938, 0.78593194, ..., 0.83454031, 0.58932155,
        0.76440328],
       [0.57371938, 0. , 0.66285896, ..., 0.89251578, 0.76511419,
        0.59261483],
       [0.78593194, 0.66285896, 0. , ..., 0.60711896, 0.80852598,
        0.73895919],
       ...,
       [0.83454031, 0.89251578, 0.60711896, ..., 0. , 1.01311994,
        0.84679914],
       [0.58932155, 0.76511419, 0.80852598, ..., 1.01311994, 0. ,
        0.5392195 ],
       [0.76440328, 0.59261483, 0.73895919, ..., 0.84679914, 0.5392195 ,
        0. ]])
CodePudding user response:
You can make it significantly faster by using numba:
import numpy as np
import numba as nb

@nb.njit(parallel=True)
def square_dist(matrix):
    dim = len(matrix)
    assert dim > 0
    dist = np.zeros((dim, dim))
    for i in nb.prange(dim):       # only the outermost prange is actually parallelised
        for j in nb.prange(dim):   # behaves like a plain range inside the parallel region
            dist[i, j] = np.square(matrix[i, :] - matrix[j, :]).sum()
    return dist
Test and time:
rng = np.random.default_rng()
matrix = rng.random((200, 10))
assert np.allclose(op(matrix), square_dist(matrix))
%timeit op(matrix)
%timeit square_dist(matrix)
Output:
181 ms ± 556 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
947 µs ± 43.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
CodePudding user response:
First, let's do a reality check. Computing N² distances, where each one takes 3N − 1 operations (N subtractions, N multiplications and N − 1 additions), means you have to perform about 3N³ arithmetic operations. When N is 100k, that totals 3×10¹⁵ operations. A modern CPU with AVX-512 running at 3 GHz (3×10⁹ Hz) can perform 3×10⁹ [cycles/sec] × (512 / 32) [float32 entries in a vector] × 2 [vector ALUs per core] ≈ 10¹¹ float32 operations per second. Therefore, computing all entries of your distance matrix will take no less than 3×10¹⁵ / 10¹¹ = 30,000 seconds, or 8 hours and 20 minutes.
This is a hard lower limit, only achievable if all operations are perfectly vectorisable, which they are not (e.g. the horizontal sum after the squaring). If the CPU isn't AVX-512 capable but only supports AVX2, the vectors are half as wide and the time goes up to about 17 hours. All this assumes the data fits in the CPU cache; it actually doesn't, so it also needs proper prefetching.
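If you want to reproduce that estimate, here it is as a few lines of Python; the hardware figures are the assumptions stated above, not measurements:
N = 100_000                     # number of rows
ops = 3 * N**3                  # ~3N³ float operations for the full distance matrix
flops = 3e9 * (512 // 32) * 2   # 3 GHz × 16 float32 lanes (AVX-512) × 2 vector ALUs ≈ 10¹¹ op/s
print(ops / flops / 3600)       # ≈ 8.7 h; rounding the throughput down to 10¹¹ gives the 8 h 20 min above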
First thing you can do is cut the compute time in half by noticing that dist[i,j] == dist[j,i] and also dist[i,i] == 0:
for i in range(dim):
    dist[i, i] = 0
    for j in range(i + 1, dim):
        dist[i, j] = dist[j, i] = np.sum(np.square(matrix[i, :] - matrix[j, :]))
Notice that the inner loop here runs only for i < j, and the call to sq_dist has been inlined to save you 5×10⁹ unnecessary function calls!
But even then, you still need more than 4 hours on that AVX-512 CPU (more than 8 hours with AVX2 only).
If you really must cut down that compute time, you need to run it in parallel. With PySpark that means you have to store the vectors in a dataset, perform a self-join, and write a UDF that uses the BLAS implementation that ships with Spark (or install a native one) to compute the distance metric. Unfortunately, this is a low-level interface of Spark and it's only exposed to UDFs written in JVM languages - check this question for a Scala-based solution.
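For intuition, regardless of where it runs, what a BLAS call accelerates here is the identity ||a − b||² = ||a||² + ||b||² − 2·a·b, which turns the bulk of the work into a single matrix product. Below is a NumPy sketch of that identity (my own illustration, not the Spark solution itself); keep in mind the full 100k × 100k result is still ~40 GB in float32, which is exactly why you would compute it block by block across the self-joined partitions:
import numpy as np

def sq_dist_gram(matrix):
    # ||a - b||² = ||a||² + ||b||² - 2·a·b: one BLAS matmul (GEMM) does the heavy lifting
    sq_norms = np.einsum('ij,ij->i', matrix, matrix)   # squared norm of every row
    gram = matrix @ matrix.T                           # all pairwise dot products
    dist = sq_norms[:, None] + sq_norms[None, :] - 2.0 * gram
    np.maximum(dist, 0.0, out=dist)                    # clamp tiny negatives caused by rounding
    return dist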