I am looking for a way to find all the vectors of a np.meshgrid(krange, krange, krange) grid that come in pairs related by k and -k.
For the moment I do this:
import numpy as np
import numba

@numba.njit
def find_pairs(array):
    boolean = np.ones(len(array), dtype=np.bool_)
    pairs = []
    idx = [i for i in range(len(array))]
    while len(idx) > 1:
        e1 = idx[0]
        for e2 in idx:
            if (array[e1] == -array[e2]).all():
                boolean[e2] = False
                pairs.append([e1, e2])
                idx.remove(e1)
                if e2 != e1:
                    idx.remove(e2)
                break
    return boolean, pairs

# Give array of 3D vectors (N is the number of grid points per axis)
krange = np.fft.fftfreq(N)
comb_array = np.array(np.meshgrid(krange, krange, krange)).T.reshape(-1, 3)
# Take idx of the pairs k, -k vector and boolean selection that give position of -k vectors
boolean, pairs = find_pairs(comb_array)
It works, but the execution time grows rapidly with N...
Maybe someone has already dealt with this?
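For small N, a brute-force reference is handy for checking any faster version. The sketch below (the name `brute_force_pairs` is hypothetical) compares every vector with every negated vector through NumPy broadcasting, so it needs O(R²) memory with R = N**3 and is only meant for tiny grids.

```python
import numpy as np

def brute_force_pairs(vectors):
    """Return all index pairs (i, j) with i <= j and vectors[i] == -vectors[j].

    Hypothetical testing helper: uses O(len(vectors)**2) memory.
    """
    # match[i, j] is True when row i is exactly the opposite of row j
    match = np.all(vectors[:, None, :] == -vectors[None, :, :], axis=2)
    # Keep each unordered pair once; the zero vector is its own opposite
    i, j = np.nonzero(np.triu(match))
    return np.stack([i, j], axis=1)

N = 3
krange = np.fft.fftfreq(N)
comb_array = np.array(np.meshgrid(krange, krange, krange)).T.reshape(-1, 3)
pairs = brute_force_pairs(comb_array)
print(len(pairs))  # 14: the zero vector plus 13 (k, -k) pairs
```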
Answer:
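The main problem is that comb_array has a shape of (R, 3), where R = N**3, and find_pairs runs in at least quadratic time in R: the while loop runs at least R/2 times (each iteration removes at most two entries from idx), and each iteration scans idx with the for loop and calls idx.remove, both of which are linear in the length of idx. Moreover, when the first remaining vector idx[0] has no opposite in the grid, the for loop removes nothing from idx, so the while loop never terminates (e.g. with N=4).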
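The vectors without an opposite come from np.fft.fftfreq itself: for an even N it contains -0.5 but not +0.5, so any grid vector with a -0.5 component has no partner. A quick check (the helper name `unmatched_frequencies` is hypothetical):

```python
import numpy as np

def unmatched_frequencies(N):
    """Return the values of fftfreq(N) whose opposite is not in fftfreq(N)."""
    k = np.fft.fftfreq(N)
    # -0.0 == 0.0 in floating point, so the zero frequency counts as matched
    return [float(v) for v in k if not np.any(k == -v)]

print(unmatched_frequencies(4))  # [-0.5]
print(unmatched_frequencies(5))  # []
```

With an odd N every frequency has its opposite, which is why the original loop can terminate in that case.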
One way to solve this problem in O(R log R) time is to sort the rows lexicographically and then look for opposite rows in a single linear pass:
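The linear pass relies on the fact that negation reverses the lexicographic order: if a < b, then -a > -b. Walking the sorted rows forward with one pointer and backward with another therefore visits k and -k candidates in matching order, like the merge step of merge sort. Below is a plain-Python sketch of that idea on tuples, independent of Numba; the name `opposite_pairs` is hypothetical.

```python
def opposite_pairs(rows):
    """Pair rows with their opposites via sorting plus a two-pointer scan.

    rows: list of equal-length tuples of numbers. Returns (i, j) pairs of
    original indices with rows[i] == -rows[j], each pair listed once.
    """
    order = sorted(range(len(rows)), key=lambda r: rows[r])
    s = [rows[r] for r in order]
    pairs = []
    i, cur = 0, len(s) - 1
    # s[i] increases with i, and -s[cur] increases as cur decreases
    while i <= cur:
        a = s[i]
        b = tuple(-x for x in s[cur])
        if a == b:
            pairs.append((order[i], order[cur]))
            i += 1
            cur -= 1
        elif a < b:
            i += 1    # no remaining row is the opposite of s[i]
        else:
            cur -= 1  # no remaining row is the opposite of s[cur]
    return pairs

print(opposite_pairs([(1, 2), (0, 0), (-1, -2), (3, 0)]))  # [(2, 0), (1, 1)]
```

Rows without an opposite, such as (3, 0) here, are simply skipped instead of stalling the scan.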
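import numpy as np
import numba as nb

N = 100  # number of grid points per axis (example value)

# Give array of 3D vectors
krange = np.fft.fftfreq(N)
comb_array = np.array(np.meshgrid(krange, krange, krange)).T.reshape(-1, 3)

# Sorting: view each row as one record so that argsort compares rows lexicographically
packed = comb_array.view([('x', 'f8'), ('y', 'f8'), ('z', 'f8')])
idx = np.argsort(packed, axis=0).ravel()
sorted_comb = comb_array[idx]

# Lexicographic comparison of two rows: returns -1, 0 or 1
@nb.njit
def lexCompare(a, b):
    for j in range(a.size):
        if a[j] < b[j]:
            return -1
        if a[j] > b[j]:
            return 1
    return 0

# Find pairs
@nb.njit
def findPairs(sorted_comb, idx):
    n = idx.size
    boolean = np.zeros(n, dtype=np.bool_)  # True at the position of the k of each pair
    pairs = []
    i = 0
    cur = n - 1
    # Negation reverses the lexicographic order, so -sorted_comb[cur] increases
    # as cur decreases: a merge-like scan finds every k/-k match in linear time.
    # Rows without an opposite (e.g. containing -0.5 when N is even) are skipped.
    while i <= cur:
        c = lexCompare(sorted_comb[i], -sorted_comb[cur])
        if c == 0:
            boolean[idx[i]] = True
            pairs.append([idx[i], idx[cur]])
            i += 1
            cur -= 1
        elif c < 0:
            i += 1
        else:
            cur -= 1
    return boolean, pairs

boolean, pairs = findPairs(sorted_comb, idx)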
Note that the algorithm assumes that each row has at most one matching opposite row. If there are several equal rows, they are paired two by two. If your goal is to extract all the combinations of equal rows in this case, note that the output can grow combinatorially (which is not reasonable IMHO).
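For the grid built from np.fft.fftfreq and np.meshgrid, every row is distinct, so the one-partner-per-row assumption holds. A quick way to confirm this for a given N, assuming N is small enough for np.unique to be cheap (the helper name `grid_rows_are_unique` is hypothetical):

```python
import numpy as np

def grid_rows_are_unique(N):
    """Check that the meshgrid of fftfreq(N) has N**3 distinct rows."""
    krange = np.fft.fftfreq(N)
    comb_array = np.array(np.meshgrid(krange, krange, krange)).T.reshape(-1, 3)
    # np.unique with axis=0 removes duplicate rows
    return np.unique(comb_array, axis=0).shape[0] == N**3

print(grid_rows_are_unique(4), grid_rows_are_unique(5))  # True True
```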
This solution is pretty fast even for N = 100. Most of the time is spent in the sort, which is not very efficient (unfortunately, NumPy does not yet provide an efficient way to do a lexicographic argsort of rows, though this operation is fundamentally expensive).
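One alternative worth timing is np.lexsort, which sorts by several keys with the last key as the primary one. The sketch below only checks that it yields the same order as the structured-view argsort on this grid; whether it is actually faster depends on the data and the NumPy version, so measure before switching.

```python
import numpy as np

N = 20
krange = np.fft.fftfreq(N)
comb_array = np.array(np.meshgrid(krange, krange, krange)).T.reshape(-1, 3)

# Structured view: each row becomes one record compared field by field
packed = comb_array.view([('x', 'f8'), ('y', 'f8'), ('z', 'f8')])
idx_struct = np.argsort(packed, axis=0).ravel()

# np.lexsort uses the LAST key as the primary key, so pass the columns as z, y, x
idx_lex = np.lexsort(comb_array.T[::-1])

print(np.array_equal(idx_struct, idx_lex))  # True, since all rows are distinct
```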