Efficiently count all the combinations of numbers having a sum close to 0

Time:03-25

I have the following pandas DataFrame df:

column1 column2 list_numbers          sublist_column
x        y      [10,-6,1,-4]             
a        b      [1,3,7,-2]               
p        q      [6,2,-3,-3.2]             

The sublist_column should contain the numbers from the "list_numbers" column that add up to 0 (within a tolerance of 0.5). I have written the following code:

def return_list(original_lst,target_sum,tolerance):
    memo=dict()
    sublist=[]
    for i, x in enumerate(original_lst):
        if memo_func(original_lst, i + 1, target_sum - x, memo, tolerance) > 0:
            sublist.append(x)
            target_sum -= x          
    return sublist  

def memo_func(original_lst, i, target_sum, memo,tolerance):
    
    if i >= len(original_lst):
        if target_sum <=tolerance and target_sum>=-tolerance:
            return 1
        else:
            return 0
    if (i, target_sum) not in memo:  
        c = memo_func(original_lst, i + 1, target_sum, memo, tolerance)
        c += memo_func(original_lst, i + 1, target_sum - original_lst[i], memo, tolerance)
        memo[(i, target_sum)] = c  
    return memo[(i, target_sum)]    
    

Then I apply the "return_list" function to the "list_numbers" column to populate "sublist_column":

target_sum = 0
tolerance=0.5

df['sublist_column'] = df['list_numbers'].apply(lambda x: return_list(x, target_sum, tolerance))

The following is the resulting DataFrame:

column1 column2 list_numbers          sublist_column
x        y      [10,-6,1,-4]             [10,-6,-4]
a        b      [1,3,7,-2]               []
p        q      [6,2,-3,-3.2]            [6,-3,-3.2]  #sum is -0.2(within the tolerance)

This gives me the correct result, but it is very slow (it takes 2 hours to run in the Spyder IDE): my DataFrame has roughly 50,000 rows, and some of the lists in the "list_numbers" column have more than 15 elements. The running time suffers particularly for those longer lists. For example, the following list takes almost 15 minutes to process:

[-1572.35,-76.16,-261.1,-7732.0,-1634.0,-52082.42,-3974.15,
-801.65,-30192.79,-671.98,-73.06,-47.72,57.96,-511.18,-391.87,-4145.0,-1008.61,
-17.53,-17.53,-1471.08,-119.26,-2269.7,-2709,-182939.59,-19.48,-516,-6875.75,-138770.16,-71.11,-295.84,-348.09,-3460.71,-704.01,-678,-632.15,-21478.76]

How can I significantly improve the running time?

Answer:

Step 1: using Numba

Based on the comments, it appears that memo_func is the main bottleneck. You can use Numba to speed up its execution. Numba compiles Python code to native code using a just-in-time (JIT) compiler. The JIT is able to perform tail-call optimizations, and native function calls are significantly faster than CPython's. Here is an example:

import numba as nb
import numpy as np

@nb.njit('(float64[::1], int64, float64, float64)')
def memo_func(original_arr, i, target_sum, tolerance):
    if i >= len(original_arr):
        if -tolerance <= target_sum <= tolerance:
            return 1
        return 0
    c = memo_func(original_arr, i + 1, target_sum, tolerance)
    c += memo_func(original_arr, i + 1, target_sum - original_arr[i], tolerance)
    return c

@nb.njit('(float64[::1], float64, float64)')
def return_list(original_arr, target_sum, tolerance):
    sublist = []
    for i, x in enumerate(original_arr):
        if memo_func(original_arr, np.int64(i + 1), target_sum - x, tolerance) > 0:
            sublist.append(x)
            target_sum -= x
    return sublist

Using memoization does not seem to speed up the computation, and it is a bit cumbersome to implement in Numba. In fact, there are much better ways to improve the algorithm.
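One plausible explanation, offered here as an assumption rather than something measured in this answer: a memo keyed on (i, target_sum) only pays off when different choices lead to the same remaining target, and with distinct floating-point amounts that rarely happens. The sketch below (the helper name subset_sum_stats is hypothetical) compares the number of subsets with the number of distinct subset sums:

```python
from itertools import combinations

def subset_sum_stats(values):
    """Return (number of subsets, number of distinct subset sums)."""
    distinct_sums = set()
    n_subsets = 0
    for r in range(len(values) + 1):
        # combinations() works on positions, so equal values still
        # produce separate subsets.
        for combo in combinations(values, r):
            distinct_sums.add(sum(combo))
            n_subsets += 1
    return n_subsets, len(distinct_sums)

# Many repeated values: few distinct sums, so a memo would be hit often.
print(subset_sum_stats([1] * 8))       # (256, 9)
# All sums differ: every memo key is new, so the memo only costs memory.
print(subset_sum_stats([1, 2, 4, 8]))  # (16, 16)
```

Lists of mostly distinct amounts with two decimals are likely to behave much more like the second case than the first.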

Note that you need to convert the lists into NumPy arrays before calling the functions:

lst = [-850.85,-856.05,-734.09,5549.63,77.59,-39.73,23.63,13.93,-6455.54,-417.07,176.72,-570.41,3621.89,-233.47,-471.54,-30.33,-941.49,-1014.6,1614.5]
result = return_list(np.array(lst, np.float64), 0, tolerance)

Step 2: tail call optimization

Calling many functions to process the right part of the input list is not efficient. The JIT is able to reduce the number of calls, but it cannot remove them completely. You can unroll the calls by hand for the deepest levels of the recursion. For example, when there are 6 items left to process, you can use the following code:

if n-i == 6:
    c = 0
    s0 = target_sum
    v0, v1, v2, v3, v4, v5 = original_arr[i:]
    for s1 in (s0, s0 - v0):
        for s2 in (s1, s1 - v1):
            for s3 in (s2, s2 - v2):
                for s4 in (s3, s3 - v3):
                    for s5 in (s4, s4 - v4):
                        for s6 in (s5, s5 - v5):
                            c += np.int64(-tolerance <= s6 <= tolerance)
    return c

This is pretty ugly but far more efficient, since the JIT is able to unroll all the loops and produce very fast code. Still, this is not enough for large lists.
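As a sanity check, the six nested loops should count exactly the same thing as a generic enumeration of all 2**6 include/exclude choices. Here is a minimal pure-Python sketch without Numba (count_unrolled and count_generic are illustrative names, not part of the code above):

```python
from itertools import product

def count_unrolled(values, target_sum, tolerance):
    """Six nested loops, mirroring the unrolled snippet (pure Python)."""
    v0, v1, v2, v3, v4, v5 = values
    c = 0
    s0 = target_sum
    for s1 in (s0, s0 - v0):
        for s2 in (s1, s1 - v1):
            for s3 in (s2, s2 - v2):
                for s4 in (s3, s3 - v3):
                    for s5 in (s4, s4 - v4):
                        for s6 in (s5, s5 - v5):
                            c += int(-tolerance <= s6 <= tolerance)
    return c

def count_generic(values, target_sum, tolerance):
    """Same count for any length: try every include/exclude mask."""
    c = 0
    for mask in product((False, True), repeat=len(values)):
        s = target_sum
        for v, keep in zip(values, mask):
            if keep:
                s -= v  # same order of subtractions as the unrolled loops
        c += int(-tolerance <= s <= tolerance)
    return c

values = [3, -1, -2, 5, -5, 4]
assert count_unrolled(values, 0, 0.5) == count_generic(values, 0, 0.5)
```

Both functions also count the empty subset when target_sum itself is within the tolerance, which matches the base case of memo_func.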


Step 3: better algorithm

For large input lists, the problem is the exponential complexity of the algorithm. This problem looks very much like a relaxed variant of subset-sum, which is known to be NP-complete. Problems in this class are known to be very hard to solve: the best exact algorithms known so far for NP-complete problems take exponential time in the worst case. In short, for any sufficiently large input, there is no known algorithm capable of finding an exact solution in a reasonable time (e.g. less than a human lifetime).
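To get a feel for the growth, here is a small back-of-the-envelope calculation. It only counts the subsets of a list of n items, which is the search space the naive recursion may have to walk; 15 is the length mentioned in the question, and 19 and 36 are the lengths of the two example lists used in this thread. The helper name n_subsets is illustrative:

```python
def n_subsets(n):
    """Number of subsets (including the empty one) of a list of n items."""
    return 2 ** n

for n in (15, 19, 36):
    print(f"n = {n:2d}: {n_subsets(n):,} subsets")
# n = 15: 32,768 subsets
# n = 19: 524,288 subsets
# n = 36: 68,719,476,736 subsets
```

Each extra item doubles the work, which is why a constant-factor speed-up alone cannot rescue the longest lists.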

That being said, there are heuristics and strategies that reduce the cost of the current algorithm. One efficient approach is to use a meet-in-the-middle algorithm. Applied to your use case, the idea is to split the list in two, generate the full set of partial sums for one half (each being target_sum minus the sum of one subset of that half), sort them, and then, for each subset sum of the other half, use a binary search to count the matching values. This works because -tolerance <= partial_sum1 - partial_sum2 <= tolerance, where partial_sum1 is one of the sorted values and partial_sum2 is a subset sum of the other half, is equivalent to -tolerance + partial_sum2 <= partial_sum1 <= tolerance + partial_sum2.
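The idea can be illustrated with a deliberately simplified pure-Python sketch. It uses only the standard library, splits the list exactly in half, and is not the optimized Numba version; subset_sums and count_close_subsets are illustrative names:

```python
from bisect import bisect_left, bisect_right

def subset_sums(values):
    """All 2**len(values) subset sums, built by doubling the list."""
    sums = [0.0]
    for v in values:
        sums += [s + v for s in sums]
    return sums

def count_close_subsets(values, target_sum, tolerance):
    """Count subsets whose sum is within `tolerance` of `target_sum`."""
    mid = len(values) // 2
    head_sums = subset_sums(values[:mid])
    # What is left of the target after removing each tail subset, sorted
    # so that a range of values can be counted by binary search.
    tail_remaining = sorted(target_sum - s for s in subset_sums(values[mid:]))
    count = 0
    for h in head_sums:
        # -tolerance <= r - h <= tolerance  <=>  h - tolerance <= r <= h + tolerance
        lo = bisect_left(tail_remaining, h - tolerance)
        hi = bisect_right(tail_remaining, h + tolerance)
        count += hi - lo
    return count

# The three rows from the question; the empty subset is counted too.
print(count_close_subsets([10, -6, 1, -4], 0, 0.5))   # 2
print(count_close_subsets([1, 3, 7, -2], 0, 0.5))     # 1
print(count_close_subsets([6, 2, -3, -3.2], 0, 0.5))  # 2
```

For n items this enumerates roughly 2 * 2**(n/2) partial sums plus a sort, instead of 2**n full subsets.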

The resulting code is unfortunately quite big and not trivial, but that is the price to pay for solving a hard problem like this one efficiently. Here it is:

import numba as nb
import numpy as np

# Generate all the target sums based on in_arr and put the result in out_sum
@nb.njit('(float64[::1], float64[::1], float64)', cache=True)
def gen_all_comb(in_arr, out_sum, target_sum):
    assert in_arr.size >= 6
    if in_arr.size == 6:
        assert out_sum.size == 64
        v0, v1, v2, v3, v4, v5 = in_arr
        s0 = target_sum
        cur = 0
        for s1 in (s0, s0 - v0):
            for s2 in (s1, s1 - v1):
                for s3 in (s2, s2 - v2):
                    for s4 in (s3, s3 - v3):
                        for s5 in (s4, s4 - v4):
                            for s6 in (s5, s5 - v5):
                                out_sum[cur] = s6
                                cur += 1
    else:
        assert out_sum.size % 2 == 0
        mid = out_sum.size // 2
        gen_all_comb(in_arr[1:], out_sum[:mid], target_sum)
        gen_all_comb(in_arr[1:], out_sum[mid:], target_sum - in_arr[0])

# Find the number of items in sorted_arr where:
# lower_bound <= item <= upper_bound
@nb.njit('(float64[::1], float64, float64)', cache=True)
def count_between(sorted_arr, lower_bound, upper_bound):
    assert lower_bound <= upper_bound
    lo_pos = np.searchsorted(sorted_arr, lower_bound, side='left')
    hi_pos = np.searchsorted(sorted_arr, upper_bound, side='right')
    return hi_pos - lo_pos

# Count all the target sums in:
# -tolerance <= all_target_sums(in_arr,sorted_target_sums)-s0 <= tolerance
@nb.njit('(float64[::1], float64[::1], float64, float64)', cache=True)
def multi_search(in_arr, sorted_target_sums, tolerance, s0):
    assert in_arr.size >= 6
    if in_arr.size == 6:
        v0, v1, v2, v3, v4, v5 = in_arr
        c = 0
        for s1 in (s0, s0 + v0):
            for s2 in (s1, s1 + v1):
                for s3 in (s2, s2 + v2):
                    for s4 in (s3, s3 + v3):
                        for s5 in (s4, s4 + v4):
                            for s6 in (s5, s5 + v5):
                                lo = -tolerance + s6
                                hi = tolerance + s6
                                c += count_between(sorted_target_sums, lo, hi)
        return c
    else:
        c = multi_search(in_arr[1:], sorted_target_sums, tolerance, s0)
        c += multi_search(in_arr[1:], sorted_target_sums, tolerance, s0 + in_arr[0])
        return c

@nb.njit('(float64[::1], int64, float64, float64)', cache=True)
def memo_func(original_arr, i, target_sum, tolerance):
    n = original_arr.size
    remaining = n - i
    tail_size = min(max(remaining//2, 7), 16)

    # Tail call: for very small list (trivial case)
    if remaining <= 0:
        return np.int64(-tolerance <= target_sum <= tolerance)

    # Tail call: for big lists (better algorithm)
    elif remaining >= tail_size*2:
        partial_sums = np.empty(2**tail_size, dtype=np.float64)
        gen_all_comb(original_arr[-tail_size:], partial_sums, target_sum)
        partial_sums.sort()
        return multi_search(original_arr[-remaining:-tail_size], partial_sums, tolerance, 0.0)

    # Tail call: for medium-sized list (unrolling)
    elif remaining == 6:
        c = 0
        s0 = target_sum
        v0, v1, v2, v3, v4, v5 = original_arr[i:]
        for s1 in (s0, s0 - v0):
            for s2 in (s1, s1 - v1):
                for s3 in (s2, s2 - v2):
                    for s4 in (s3, s3 - v3):
                        for s5 in (s4, s4 - v4):
                            for s6 in (s5, s5 - v5):
                                c += np.int64(-tolerance <= s6 <= tolerance)
        return c

    # Recursion
    c = memo_func(original_arr, i + 1, target_sum, tolerance)
    c += memo_func(original_arr, i + 1, target_sum - original_arr[i], tolerance)
    return c

@nb.njit('(float64[::1], float64, float64)', cache=True)
def return_list(original_arr, target_sum, tolerance):
    sublist = []
    for i, x in enumerate(original_arr):
        if memo_func(original_arr, np.int64(i + 1), target_sum - x, tolerance) > 0:
            sublist.append(x)
            target_sum -= x
    return sublist

Note that the code takes a few seconds to compile since it is quite big. The cache should help avoid recompiling it every time.


Benchmark

Here are the tested inputs:

target_sum = 0
tolerance = 0.5

small_lst = [-850.85,-856.05,-734.09,5549.63,77.59,-39.73,23.63,13.93,-6455.54,-417.07,176.72,-570.41,3621.89,-233.47,-471.54,-30.33,-941.49,-1014.6,1614.5]
big_lst = [-1572.35,-76.16,-261.1,-7732.0,-1634.0,-52082.42,-3974.15,-801.65,-30192.79,-671.98,-73.06,-47.72,57.96,-511.18,-391.87,-4145.0,-1008.61,-17.53,-17.53,-1471.08,-119.26,-2269.7,-2709,-182939.59,-19.48,-516,-6875.75,-138770.16,-71.11,-295.84,-348.09,-3460.71,-704.01,-678,-632.15,-21478.76]

Here is the timing with the small list on my machine:

Naive python algorithm:               173.45 ms
Naive algorithm using Numba:            7.21 ms
Tail call optimization + Numba:         0.33 ms
Efficient algorithm + optim + Numba:    0.16 ms

Here is the timing with the big list on my machine:

Naive python algorithm:               >20000 s    [estimation & out of memory]
Naive algorithm using Numba:            ~900 s    [estimation]
Tail call optimization + Numba:           42.61 s
Efficient algorithm + optim + Numba:       0.05 s

Thus, the final implementation is about 1000 times faster on the small input and more than 400_000 times faster on the large input! It also uses far less RAM, so it can actually be executed on a basic PC.

It is worth noting that the execution time can be reduced even further by using multiple threads, reaching a speed-up of >1_000_000, though this may be slower on small inputs and will make the code a bit more complex.
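One simple way to use several threads, sketched below under stated assumptions, is to spread the rows of the DataFrame over a thread pool rather than to parallelize the algorithm itself. Here find_sublist is a hypothetical brute-force stand-in for return_list so that the example runs with the standard library only. With a pure-Python worker like this one, the GIL prevents any real speed-up; threads only help if the worker releases the GIL, for example a Numba function compiled with nogil=True.

```python
from concurrent.futures import ThreadPoolExecutor
from itertools import combinations

def find_sublist(values, target_sum=0.0, tolerance=0.5):
    """Hypothetical stand-in for return_list: the first non-empty subset
    whose sum is within `tolerance` of `target_sum`, or [] if none."""
    for r in range(1, len(values) + 1):
        for combo in combinations(values, r):
            if abs(sum(combo) - target_sum) <= tolerance:
                return list(combo)
    return []

# The three example rows from the question.
rows = [[10, -6, 1, -4], [1, 3, 7, -2], [6, 2, -3, -3.2]]

with ThreadPoolExecutor(max_workers=4) as pool:
    # map() preserves the input order, so results line up with rows.
    results = list(pool.map(find_sublist, rows))

print(results)  # [[10, -6, -4], [], [6, -3, -3.2]]
```

Row-level parallelism of this kind leaves the algorithm untouched, so it keeps the code simple, but a single very long list would still be processed by one thread.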
