Find pairs of array rows such that array_1 = -array_2

Time:04-17

I am looking for a way to find all the vectors from a np.meshgrid(krange, krange, krange) that come in pairs related by k = -k.

For the moment I do this:


import numpy as np
import numba

@numba.njit
def find_pairs(array):
    # Set to False at the position of each -k partner that gets found
    boolean = np.ones(len(array), dtype=np.bool_)
    pairs = []
    # Indices of the vectors that have not been paired yet
    idx = [i for i in range(len(array))]
    while len(idx) > 1:
        e1 = idx[0]
        # Linear scan for the opposite of array[e1]
        for e2 in idx:
            if (array[e1] == -array[e2]).all():
                boolean[e2] = False
                pairs.append([e1, e2])
                idx.remove(e1)
                if e2 != e1:
                    idx.remove(e2)
                break 
    return boolean, pairs

# Build the array of 3D vectors
N = 5  # grid size (example value)
krange = np.fft.fftfreq(N)
comb_array = np.array(np.meshgrid(krange, krange, krange)).T.reshape(-1, 3)

# Indices of the (k, -k) pairs, and a boolean mask that is False at the positions of the -k vectors
boolean, pairs = find_pairs(comb_array)

It works, but the execution time grows rapidly with N.

Has anyone already dealt with this?

Answer:

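The main problem is that comb_array has a shape of (R, 3), where R = N**3, and the nested loop in find_pairs runs in at least quadratic time, since idx.remove runs in linear time and is called inside the for loop. For N = 100, R = 10**6, so a quadratic algorithm already needs on the order of 10**12 operations. Moreover, there are cases where the for loop does not change the size of idx, and the while loop then runs forever (e.g. with N = 4). This happens as soon as idx[0] is a vector whose opposite is not in the array.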
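To see why the loop can hang, the small brute-force check below (only meant for tiny grids) counts how many rows of the N = 4 grid actually have their opposite in the array. For even N, np.fft.fftfreq(N) contains -0.5 but not +0.5, so any row with a -0.5 component has no partner; once such a row reaches idx[0], nothing is removed from idx.

```python
import numpy as np

# For even N, fftfreq contains -0.5 but not +0.5
krange = np.fft.fftfreq(4)  # values: 0.0, 0.25, -0.5, -0.25

# Same grid construction as in the question
comb_array = np.array(np.meshgrid(krange, krange, krange)).T.reshape(-1, 3)

# Brute force: does the opposite of each row appear somewhere in the array?
has_opposite = np.array(
    [np.any(np.all(comb_array == -row, axis=1)) for row in comb_array]
)
print(has_opposite.sum(), "of", len(comb_array))  # 27 of 64
```

Only the 27 rows built from {0, 0.25, -0.25} have a partner; the other 37 rows contain -0.5.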

One way to solve this problem in O(R log R) time is to sort the rows lexicographically and then look for opposite values in linear time:
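The linear scan relies on the fact that negating every component reverses lexicographic order, so in a set that contains every vector's opposite, the i-th smallest row is the opposite of the i-th largest one. A pure-NumPy sketch of that observation, using np.lexsort instead of a structured view and assuming an odd N so that the grid is symmetric:

```python
import numpy as np

N = 5  # assumed odd: fftfreq(N) is then symmetric, so every vector has its opposite
krange = np.fft.fftfreq(N)
comb_array = np.array(np.meshgrid(krange, krange, krange)).T.reshape(-1, 3)

# np.lexsort uses the LAST key as the primary key, so pass (z, y, x)
# to sort by x first, then y, then z
order = np.lexsort((comb_array[:, 2], comb_array[:, 1], comb_array[:, 0]))
sorted_comb = comb_array[order]

# Negation reverses the order: sorted row i is the opposite of sorted row n-1-i
mirrored = np.all(sorted_comb == -sorted_comb[::-1], axis=1)
print(mirrored.all())  # True

# Pair the i-th smallest with the i-th largest; the middle row (the zero vector) pairs with itself
pairs = np.stack([order, order[::-1]], axis=1)[: (len(order) + 1) // 2]
```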

import numpy as np
import numba as nb

# Build the array of 3D vectors
N = 11  # grid size (example value); odd N makes fftfreq(N) symmetric around 0
krange = np.fft.fftfreq(N)
comb_array = np.array(np.meshgrid(krange, krange, krange)).T.reshape(-1, 3)

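# Sort the rows lexicographically: viewing each row as an (x, y, z) record
# makes argsort compare x first, then y, then z
packed = comb_array.view([('x', 'f8'), ('y', 'f8'), ('z', 'f8')])
idx = np.argsort(packed, axis=0).ravel()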
sorted_comb = comb_array[idx]

# Find pairs
@nb.njit
def findPairs(sorted_comb, idx):
    n = idx.size
    boolean = np.zeros(n, dtype=np.bool_)
    pairs = []
    cur = n-1
    for i in range(n):
        while cur >= i:
            if np.all(sorted_comb[i] == -sorted_comb[cur]):
                boolean[idx[i]] = True
                pairs.append([idx[i], idx[cur]])
                cur -= 1
                break
            cur -= 1
    return boolean, pairs

boolean, pairs = findPairs(sorted_comb, idx)

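Note that the algorithm assumes that each row has at most one valid matching row. If there are several equal rows, they are paired two by two. If your goal is to extract all the combinations of equal rows in this case, then note that the output will grow exponentially (which is not reasonable IMHO). The algorithm also assumes that every row has its opposite in the array: when a row has no opposite, the inner while loop decrements cur all the way down and no further pairs are found. This is the case for even N, because np.fft.fftfreq(N) then contains -0.5 but not +0.5.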

This solution is fast even for N = 100. Most of the time is spent in the sort, which is not very efficient (unfortunately, NumPy does not yet provide an efficient way to do a lexicographic argsort of rows, though this operation is fundamentally expensive anyway).
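One way around the lexicographic sort, sketched here under the assumption that every component is an integer multiple of 1/N (true for np.fft.fftfreq(N) with the default spacing d=1.0), is to encode each row as a single int64 key, sort that one column, and look up the key of -k with np.searchsorted. The function name find_opposites is hypothetical, and the speed gain over the record sort is an untested expectation. Rows without an opposite simply get -1 instead of stopping the scan, so even N works too.

```python
import numpy as np

def find_opposites(comb_array, N):
    """Hypothetical helper: partner[i] = j with comb_array[j] == -comb_array[i], or -1.

    Assumes every component is an integer multiple of 1/N (np.fft.fftfreq(N), d=1.0).
    """
    ints = np.rint(comb_array * N).astype(np.int64)  # integers in [-(N//2), (N-1)//2]
    base = 2 * N + 1  # each shifted component (k + N) fits in [0, 2N]

    def encode(v):
        # Mixed-radix encoding: one unique int64 key per integer row
        s = v + N
        return (s[:, 0] * base + s[:, 1]) * base + s[:, 2]

    key = encode(ints)
    neg_key = encode(-ints)

    order = np.argsort(key)  # single-column sort instead of a lexicographic one
    sorted_key = key[order]
    pos = np.searchsorted(sorted_key, neg_key)
    pos = np.minimum(pos, len(key) - 1)  # keep indices in range for the comparison
    found = sorted_key[pos] == neg_key
    return np.where(found, order[pos], -1)

N = 4
krange = np.fft.fftfreq(N)
comb_array = np.array(np.meshgrid(krange, krange, krange)).T.reshape(-1, 3)
partner = find_opposites(comb_array, N)

# Turn the partner array into (k, -k) index pairs, keeping each pair once
rows = np.arange(len(comb_array))
matched = partner >= 0
pairs = np.stack([rows[matched], partner[matched]], axis=1)
pairs = pairs[pairs[:, 0] <= pairs[:, 1]]
```

For N = 4 this gives 14 pairs: 13 pairs of distinct opposite vectors plus the zero vector paired with itself, while the 37 rows containing -0.5 get partner -1.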
