I have a large matrix of N = 10000 3D vectors. To simplify, I will use a 10 x 3 matrix here as an example:
```python
import numpy as np

A = np.array([[1.2, 2.3, 0.8],
              [3.2, 2.1, 0.5],
              [0.8, 4.4, 4.4],
              [-0.2, -1.1, -1.1],
              [2.4, 4.6, 1.6],
              [0.5, 0.96, 0.33],
              [1.1, 2.2, 3.3],
              [-2.2, -4.41, -6.62],
              [3.4, 5.5, 3.8],
              [-5.1, -28., -28.1]])
```
I want to find all unique pairs of vectors that are nearly parallel to each other. A tolerance has to be used, and I want to get all unique pairs of row indices (regardless of order). I managed to write the following code:
```python
def all_parallel_pairs(A, tol=0.1):
    res = set()
    for i, v1 in enumerate(A):
        for j, v2 in enumerate(A):
            if i == j:
                continue
            # the cross product of (nearly) parallel vectors has a norm close to 0
            norm = np.linalg.norm(np.cross(v1, v2))
            if np.isclose(norm, 0., rtol=0, atol=tol):
                # store each unordered pair once, as (smaller, larger)
                res.add(tuple(sorted([i, j])))
    return np.array(list(res))

print(all_parallel_pairs(A, tol=0.1))
```
Output:

```
[[0 4]
 [2 3]
 [6 7]
 [4 5]
 [0 5]]
```
However, since I am using two for loops, it becomes slow when N is large. I feel like there should be more efficient and more NumPy-idiomatic ways to do this. Any suggestions?
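The loop tests for parallelism through the norm of the cross product. For nonzero vectors with angle θ_ij ∈ [0, π] between rows i and j, the standard identity below holds, so the norm is near 0 both for nearly parallel and for nearly anti-parallel vectors (rows 2 and 3 in the output are anti-parallel: row 3 is -0.25 times row 2). Because the loop compares the raw norm with an absolute `tol`, the test also depends on the vectors' lengths; the second line shows a length-independent variant that divides by the product of the norms (a suggested alternative, not part of the original code).

```latex
\begin{aligned}
\lVert \mathbf{v}_i \times \mathbf{v}_j \rVert &= \lVert \mathbf{v}_i \rVert \, \lVert \mathbf{v}_j \rVert \, \sin\theta_{ij}, \qquad \theta_{ij} \in [0, \pi], \\
\frac{\lVert \mathbf{v}_i \times \mathbf{v}_j \rVert}{\lVert \mathbf{v}_i \rVert \, \lVert \mathbf{v}_j \rVert} \le \varepsilon &\iff \sin\theta_{ij} \le \varepsilon .
\end{aligned}
```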
Answer:
Note that `np.cross` accepts arrays of vectors. From the documentation:

"Return the cross product of two (arrays of) vectors."
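A minimal sketch of that behavior, using two small made-up stacks `U` and `V` (hypothetical data, one 3D vector per row): a single call returns the row-wise cross products `U[k] x V[k]`, and rows built from parallel or anti-parallel vectors come out as zero.

```python
import numpy as np

# Two hypothetical stacks of 3D vectors, one vector per row
U = np.array([[1.0, 0.0, 0.0],
              [0.0, 1.0, 0.0],
              [1.0, 2.0, 3.0]])
V = np.array([[2.0, 0.0, 0.0],    # parallel to U[0]
              [0.0, 0.0, 1.0],    # perpendicular to U[1]
              [-2.0, -4.0, -6.0]])  # anti-parallel to U[2]

# One call computes all row-wise cross products U[k] x V[k]
C = np.cross(U, V)
print(C)
# [[0. 0. 0.]
#  [1. 0. 0.]
#  [0. 0. 0.]]
```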
So one approach is to use NumPy advanced indexing to select the pairs of vectors whose cross products need to be computed:
```python
import numpy as np

# generate the (i, j) index pairs; only the upper triangle (i < j) is needed
rows, cols = np.triu_indices(A.shape[0], 1)
# compute all cross products at once: A[rows] and A[cols] are (k, 3) arrays
cross = np.cross(A[rows], A[cols])
# find the pairs whose cross-product norm is close to 0
arg = np.argwhere(np.isclose(0, (cross * cross).sum(axis=1) ** 0.5, rtol=0, atol=0.1))
# arg has shape (m, 1), so rows[arg] and cols[arg] are (m, 1) columns
res = np.hstack([rows[arg], cols[arg]])
print(res)
```
Output:

```
[[0 4]
 [0 5]
 [2 3]
 [4 5]
 [6 7]]
```
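To check the vectorized version against the loop, here is a sketch that wraps the same steps in a function and compares it with a loop reference on the question's data. The names `parallel_pairs_vectorized` and `parallel_pairs_loop` are hypothetical, and the function uses a boolean mask instead of `np.argwhere` plus `np.hstack`, which selects the same pairs.

```python
import numpy as np

def parallel_pairs_vectorized(A, tol=0.1):
    """Index pairs (i < j) whose cross-product norm is <= tol, as an (m, 2) array."""
    rows, cols = np.triu_indices(A.shape[0], 1)
    cross = np.cross(A[rows], A[cols])
    norms = np.sqrt((cross * cross).sum(axis=1))
    # isclose(0, x, rtol=0, atol=tol) is the same test as |x| <= tol
    mask = np.isclose(0, norms, rtol=0, atol=tol)
    return np.column_stack([rows[mask], cols[mask]])

def parallel_pairs_loop(A, tol=0.1):
    """Reference: the question's double loop, visiting each pair once (j > i)."""
    pairs = []
    for i in range(len(A)):
        for j in range(i + 1, len(A)):
            if np.linalg.norm(np.cross(A[i], A[j])) <= tol:
                pairs.append((i, j))
    return np.array(pairs).reshape(-1, 2)

A = np.array([[1.2, 2.3, 0.8], [3.2, 2.1, 0.5], [0.8, 4.4, 4.4],
              [-0.2, -1.1, -1.1], [2.4, 4.6, 1.6], [0.5, 0.96, 0.33],
              [1.1, 2.2, 3.3], [-2.2, -4.41, -6.62], [3.4, 5.5, 3.8],
              [-5.1, -28., -28.1]])

print(parallel_pairs_vectorized(A, tol=0.1))
# [[0 4]
#  [0 5]
#  [2 3]
#  [4 5]
#  [6 7]]
print(np.array_equal(parallel_pairs_vectorized(A), parallel_pairs_loop(A)))  # True
```

Because `np.triu_indices` enumerates pairs in row-major order, the result comes out sorted by `i` and then `j`, unlike the set-based loop in the question.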
The expression `(cross * cross).sum(axis=1) ** 0.5` is a faster alternative to applying `np.linalg.norm` over an array of vectors (`np.linalg.norm(cross, axis=1)` gives the same values).
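A quick sketch checking that the expression matches `np.linalg.norm(cross, axis=1)`, with `np.einsum('ij,ij->i', cross, cross)` shown as another way to get row-wise squared norms (random data stands in for `cross`). Since the answer only compares the norm with a threshold, comparing squared norms with `tol**2` skips the square root entirely. Which variant is fastest depends on array size and NumPy version, so it is worth timing on your own data.

```python
import numpy as np

# Random stand-in for the (k, 3) array of cross products (illustration only)
rng = np.random.default_rng(0)
cross = rng.normal(size=(1000, 3))

manual = (cross * cross).sum(axis=1) ** 0.5   # expression from the answer
via_norm = np.linalg.norm(cross, axis=1)      # built-in, row-wise
sq = np.einsum('ij,ij->i', cross, cross)      # row-wise squared norms
via_einsum = np.sqrt(sq)

print(np.allclose(manual, via_norm), np.allclose(manual, via_einsum))  # True True

# For a threshold test the square root is unnecessary: compare squared norms with tol**2
tol = 0.1
mask = sq <= tol ** 2
print(mask.sum(), "rows below the threshold")
```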
Answer:
As an improvement on Dani Masejo's answer, you can use GPU- or TPU-accelerated libraries such as JAX:
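```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def parallel_mask(A, tol=0.1):
    # use jax.numpy inside jit; A.shape[0] is static, so triu_indices works here
    rows, cols = jnp.triu_indices(A.shape[0], 1)
    cross = jnp.cross(A[rows], A[cols])
    norms = jnp.sqrt((cross * cross).sum(axis=1))
    return rows, cols, jnp.isclose(0., norms, rtol=0, atol=tol)

# argwhere has a data-dependent output shape, so it is done outside jit
rows, cols, mask = parallel_mask(jnp.asarray(A))
idx = np.flatnonzero(np.asarray(mask))
res = np.stack([np.asarray(rows)[idx], np.asarray(cols)[idx]], axis=1)
print(res)
```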
The timings below were obtained on a Google Colab TPU runtime:
100 loops, best of 5: 12.2 ms per loop # the question code
100 loops, best of 5: 152 µs per loop # Dani Masejo code
100 loops, best of 5: 81.5 µs per loop # using jax library
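The differences become more significant as the data volume increases. Keep in mind that if the body of a jitted function calls NumPy rather than `jax.numpy`, the NumPy code runs once while JAX traces the function and later calls only return the precomputed result, so such a timing does not reflect the JAX computation itself.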
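As the data volume grows, memory can become the bottleneck before speed does: for N = 10000 the index-pair approach builds N(N-1)/2 = 49,995,000 pairs, and `A[rows]`, `A[cols]` and `cross` each hold that many float64 3-vectors (about 1.2 GB apiece). A minimal sketch that keeps the same criterion but processes one block of rows at a time against all later rows, using broadcasting in `np.cross`. The function name `parallel_pairs_chunked` and the `block` parameter are hypothetical, chosen for illustration; the test on squared norms against `tol * tol` is equivalent to the norm test used above.

```python
import numpy as np

def parallel_pairs_chunked(A, tol=0.1, block=256):
    """Pairs (i, j), i < j, with ||A[i] x A[j]|| <= tol, computed block by block.

    Peak memory per step grows with block * N instead of N**2.
    """
    n = A.shape[0]
    out = []
    for start in range(0, n, block):
        stop = min(start + block, n)
        # Earlier rows were already paired with this block, so compare with rows start..n-1
        rest = A[start:]
        # Broadcast (b, 1, 3) against (1, m, 3) -> (b, m, 3) cross products
        cross = np.cross(A[start:stop, None, :], rest[None, :, :])
        sq = np.einsum('ijk,ijk->ij', cross, cross)  # squared norms, shape (b, m)
        i_loc, j_loc = np.nonzero(sq <= tol * tol)
        i, j = i_loc + start, j_loc + start
        keep = j > i  # drop i == j and keep each unordered pair once
        out.append(np.column_stack([i[keep], j[keep]]))
    return np.concatenate(out) if out else np.empty((0, 2), dtype=int)

A = np.array([[1.2, 2.3, 0.8], [3.2, 2.1, 0.5], [0.8, 4.4, 4.4],
              [-0.2, -1.1, -1.1], [2.4, 4.6, 1.6], [0.5, 0.96, 0.33],
              [1.1, 2.2, 3.3], [-2.2, -4.41, -6.62], [3.4, 5.5, 3.8],
              [-5.1, -28., -28.1]])

# A tiny block size just to exercise the chunking on the 10-row example
print(parallel_pairs_chunked(A, tol=0.1, block=3))
# [[0 4]
#  [0 5]
#  [2 3]
#  [4 5]
#  [6 7]]
```

The block size trades memory for the number of Python-level iterations; a few hundred rows per block keeps each step's temporaries in the tens of megabytes for N = 10000.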