Count all unique quadruples that sum to a given value - is N^3 complexity algorithm known?

Time:11-01

I am supposed to solve this problem with the lowest time complexity possible. Let me be more specific.

You are given a sorted array of integers that contains duplicates.

A unique quadruple is a set of four indexes. The elements of the array at those indexes have to sum to a given value X. For example:

  1. Given an array [10, 20, 30, 40] and X = 100, there is only one quadruple: (0, 1, 2, 3).

  2. Given an array [0, 0, 0, 0, 0] and X = 0, there are 5 quadruples: (0, 1, 2, 3), (0, 1, 2, 4), (0, 1, 3, 4), (0, 2, 3, 4), (1, 2, 3, 4).
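A minimal brute-force Python sketch of this definition, enumerating every index set i < j < k < l exactly once (the name `count_index_quadruples` is hypothetical, chosen for illustration):

```python
from itertools import combinations

def count_index_quadruples(arr, X):
    # combinations(range(n), 4) yields each index set i < j < k < l once,
    # so equal values at different indexes are counted separately.
    return sum(
        1
        for i, j, k, l in combinations(range(len(arr)), 4)
        if arr[i] + arr[j] + arr[k] + arr[l] == X
    )

print(count_index_quadruples([10, 20, 30, 40], 100))  # 1
print(count_index_quadruples([0, 0, 0, 0, 0], 0))     # 5
```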

On the internet there are plenty of N^3 solutions, but those are for quadruples that are unique in terms of values, not indexes. With those solutions, example no. 1 would still give only one quadruple, (10, 20, 30, 40), but example no. 2 gives only one quadruple, (0, 0, 0, 0), not five of them.

I couldn't find an O(N^3) solution to my problem, as opposed to the other one. I can easily write a program that solves it in O(N^3logN) time. I have also heard that the lower complexity bound for this problem is allegedly not known. Is an O(N^3) solution known, though?

Solutions known to me:

  1. Obvious naive approach O(N^4):
int solution(int arr[], int arrSize, int X){
    int counter = 0;
    for(int i=0; i<arrSize-3; ++i)
        for(int j=i+1; j<arrSize-2; ++j)
            for(int k=j+1; k<arrSize-1; ++k)
                for(int l=k+1; l<arrSize; ++l)
                    if(arr[i]+arr[j]+arr[k]+arr[l] == X) ++counter;
    return counter;
}
  2. Approach using triplets and binary search O(N^3logN):
int solution(int arr[], int arrSize, int X){
    int counter = 0;
    for(int i=0; i<arrSize-3; ++i)
        for(int j=i+1; j<arrSize-2; ++j)
            for(int k=j+1; k<arrSize-1; ++k){
                int subX = X - arr[i] - arr[j] - arr[k];
                int first = binFirst(subX, arr, k+1, arrSize);
                //binary search that returns position of first
                //occurrence of subX in arr in range [k+1, arrSize)
                //or -1 if not found
                int last = binLast(subX, arr, k+1, arrSize);
                //binary search that returns position of last
                //occurrence of subX in arr in range [k+1, arrSize)
                //or -1 if not found
                if(first != -1) counter += last - first + 1;
            }
    return counter;
}

Naturally, the above algorithm could be improved by counting all duplicates of arr[i], arr[j], arr[k], but as far as I can tell this does not lower the actual O(N^3logN) complexity.
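A Python sketch of the same O(N^3 log N) idea, assuming the standard-library `bisect` module in place of the question's `binFirst`/`binLast` helpers (which are not shown). `bisect_right(...) - bisect_left(...)` over the range [k+1, n) is directly the number of positions holding subX, so no -1 sentinel is needed. The name `solution_bisect` is illustrative only.

```python
from bisect import bisect_left, bisect_right

def solution_bisect(arr, X):
    # arr must be sorted; counts index quadruples i < j < k < l summing to X.
    n = len(arr)
    counter = 0
    for i in range(n - 3):
        for j in range(i + 1, n - 2):
            for k in range(j + 1, n - 1):
                sub_x = X - arr[i] - arr[j] - arr[k]
                # Number of positions l in [k+1, n) with arr[l] == sub_x.
                first = bisect_left(arr, sub_x, k + 1, n)
                last = bisect_right(arr, sub_x, k + 1, n)
                counter += last - first
    return counter

print(solution_bisect([10, 20, 30, 40], 100))  # 1
print(solution_bisect([0, 0, 0, 0, 0], 0))     # 5
```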

Answer:

O(n²) in Python, inspired by גלעד ברקן's answer:

from itertools import combinations
from collections import Counter

def solution(arr, X):
    cd = Counter(map(sum, combinations(arr, 2)))
    count = 0
    for i, b in enumerate(arr):
        for d in arr[i+1:]:
            cd[b+d] -= 1
        for a in arr[:i]:
            count += cd[X - (a+b)]
    return count

Call the quadruples (a,b,c,d). We focus on the second element, b. For each possible b, we add each possible a (elements left of b), and look up how many pairs (c,d) (elements right of b) complete the sum a+b+c+d = X, i.e., sum to X - (a+b). For that lookup, we have a hash map cd that maps sums of pairs to counts of pairs. Initially, it contains all pairs of the whole arr, but for each b we consider, we remove its contributions (the pairs (b,d) with d right of b) from the map. Since the pairs starting at earlier elements were removed in earlier steps, cd then holds exactly the pairs lying entirely to the right of b.
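The argument rests on one invariant: right after the removal step for the element at index i, cd counts exactly the pairs (c, d) with i < c < d. A small Python sketch that checks this by recomputing the expected counts from scratch at every step (`check_invariant` is a hypothetical helper, not part of the answer):

```python
from collections import Counter
from itertools import combinations

def check_invariant(arr):
    # Same start as the answer: counts of all pair sums in arr.
    cd = Counter(map(sum, combinations(arr, 2)))
    for i, b in enumerate(arr):
        # Same update as the answer: drop pairs (b, d) with d right of b.
        for d in arr[i + 1:]:
            cd[b + d] -= 1
        # Recompute from scratch: sums of pairs lying strictly right of index i.
        expected = Counter(map(sum, combinations(arr[i + 1:], 2)))
        # Ignore sums whose count dropped to zero before comparing.
        remaining = {s: c for s, c in cd.items() if c != 0}
        assert remaining == dict(expected), f"invariant broken at index {i}"
    return True

print(check_invariant([1, 1, 2, 3, 5, 8]))  # True
```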

C++ version, where a/b/c/d are indexes instead of elements:

#include <unordered_map>

int solution(int arr[], int n, int X){
  std::unordered_map<int, int> cd;
  for (int c=0; c<n; c++)
    for (int d=c+1; d<n; d++)
      cd[arr[c]+arr[d]]++;
  int count = 0;
  for (int b=0; b<n; b++) {
    for (int d=b+1; d<n; d++)
      cd[arr[b]+arr[d]]--;
    for (int a=0; a<b; a++)
      count += cd[X - (arr[a]+arr[b])];
  }
  return count;
}

Python code with testing:

from itertools import combinations
from collections import Counter

def solution(arr, X):
    cd = Counter(map(sum, combinations(arr, 2)))
    count = 0
    for i, b in enumerate(arr):
        for d in arr[i+1:]:
            cd[b+d] -= 1
        for a in arr[:i]:
            count += cd[X - (a+b)]
    return count

import random
from operator import countOf

def naive(arr, X):
    sums = map(sum, combinations(arr, 4))
    return countOf(sums, X)

arr = random.choices(range(100), k=100)
print(naive(arr, 200))
print(solution(arr, 200))


Answer:

We can do it in O(n^2) time and space by updating the hash map dynamically.

Start by creating a hash map from each pair sum to the set of tuples (index pairs) that compose it, traversing from the left, and store for each index the tuples it belongs to (O(n) of them). Stop when only the two rightmost elements are left, unhashed.

Now traverse towards the left: starting with the third-rightmost element, remove all the tuples the current element belongs to (O(n) of them). Then, for each sum the element can create by pairing with an element on its right, add the count of tuples in the corresponding hashed sum that would complete the overall sum. Because we removed all instances where the current element was used on the left, the quadruples are guaranteed to be partitioned: none of the elements on the right is represented in the hashed counts from the left.
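A minimal Python sketch of this right-to-left scheme, with two assumptions of the sketch that the answer does not spell out. First, since only the number of quadruples is needed, the hash map stores a count per sum instead of the set of index tuples, so removing a tuple becomes decrementing a count. Second, each step counts completions for the current element c before removing the pairs that end just to its left, so the hash always holds exactly the pairs (a, b) with b < c. `solution_right_to_left` is a hypothetical name.

```python
from collections import Counter

def solution_right_to_left(arr, X):
    n = len(arr)
    # Hash the sums of all pairs (a, b), a < b, traversing from the left and
    # leaving the two rightmost elements unhashed (b <= n - 3).
    left = Counter()
    for b in range(n - 2):
        for a in range(b):
            left[arr[a] + arr[b]] += 1
    count = 0
    # Traverse leftwards; c is the third index of the quadruple (a, b, c, d).
    for c in range(n - 2, 1, -1):
        # Here `left` holds exactly the pairs (a, b) with b < c.
        for d in range(c + 1, n):
            count += left[X - arr[c] - arr[d]]
        # Remove every pair ending at c - 1 before moving one step left.
        for a in range(c - 1):
            left[arr[a] + arr[c - 1]] -= 1
    return count

print(solution_right_to_left([0, 0, 0, 0, 0], 0))  # 5
```

Both loops do O(n) work per element, so the sketch runs in O(n^2) time, and the Counter holds at most O(n^2) distinct sums.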

Answer:

The array is sorted, which means we can use binary search.

Now, suppose we create an array of the sums of all pairs, for example:

arr = [10, 20, 30, 40]
pairs = [10+20, 10+30, 10+40, 20+30, 20+40, 30+40]

There is a pattern: we have 3 pairs for 10+x, 2 pairs for 20+x, 1 pair for 30+x, and 0 pairs for 40+x.

 [10+20, 10+30, 10+40, 20+30, 20+40, 30+40]
# -------------------  ------------  -----

 [30, 40, 50, 50, 60, 70]
# ----------  ------  --

So the total number of pairs is

3 + 2 + 1
= sum of first (n - 1) natural numbers
= (n - 1) * (n - 1 + 1) / 2
= (n - 1) * n / 2
= (n^2 - n) / 2
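A quick Python check of this count for the example above (`pair_sums` is a hypothetical helper that lists the sums in the same row-by-row order):

```python
def pair_sums(arr):
    # arr[i] + arr[j] for every i < j, row by row: all (0, x) pairs, then (1, x), ...
    n = len(arr)
    return [arr[i] + arr[j] for i in range(n - 1) for j in range(i + 1, n)]

pairs = pair_sums([10, 20, 30, 40])
print(pairs)  # [30, 40, 50, 50, 60, 70]
n = 4
print(len(pairs) == (n * n - n) // 2)  # True
```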

It might look like the whole pairs array is sorted, but that is not true: only the sub-arrays of pairs (one per first element) are guaranteed to be sorted, because the initial arr is sorted. For example:

arr = [10, 20, 30, 90]
pairs = [10+20, 10+30, 10+90, 20+30, 20+90, 30+90]

# Those sub-arrays are sorted
 [30, 40, 100, 50, 110, 120]
# -----------  -------  ---
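The same claim in Python: each row of sums is sorted, but the flattened list need not be (`pair_sum_groups` is a hypothetical helper):

```python
def pair_sum_groups(arr):
    # One list per first index i, holding arr[i] + arr[j] for every j > i.
    n = len(arr)
    return [[arr[i] + arr[j] for j in range(i + 1, n)] for i in range(n - 1)]

groups = pair_sum_groups([10, 20, 30, 90])
print(groups)                                 # [[30, 40, 100], [50, 110], [120]]
print(all(g == sorted(g) for g in groups))    # True: each group is sorted
flat = [s for g in groups for s in g]
print(flat == sorted(flat))                   # False: 100 comes before 50
```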

Now, let's write the pairs using their original arr indices:

pairs = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]

(0, 1) and (0, 2) cannot form a valid quadruple because index 0 appears in both pairs. So how can we systematically find the valid pairs?

There is only one valid partner for (0, 1): (2, 3), which contains neither 0 nor 1.

 [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
#  x  x    x       x       x       x       ----

One useful fact: we can always write a quadruple so that one pair lies entirely to the left of the other, for example:

x = 100
arr = [10, 20, 30, 40]
pairs = [30, 40, 50, 50, 60, 70]

 [10, 20, 30, 40]
# --  ------  --
quadruple = (10 + 40) + (20 + 30)

# which can be rewritten as
 [10, 20, 30, 40]
# ------  ------
quadruple = (10 + 20) + (30 + 40) = 30 + 70

# Which is as follows
pairs = [30, 40, 50, 50, 60, 70]
#        --                  --

So we can solve the problem as follows:

for pair0 in pairs:
    valid_pairs_for_pair0 = # Somehow get the valid pairs
    for pair1 in valid_pairs_for_pair0:
        if pair0 + pair1 == x:
            ans += 1

But the above solution is O(n^4), because pairs has length (n^2 - n) / 2 and we loop over it in two nested loops.

We can do better, since we know the sub-arrays of pairs are sorted:
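The pseudocode above leaves "Somehow get the valid pairs" open. Using the fact that one pair can always lie entirely left of the other, a pair (k, l) is a valid partner for (i0, i1) exactly when k > i1; this also counts each quadruple exactly once. A runnable Python sketch of that O(n^4) version (`count_pairs_of_pairs` is a hypothetical name):

```python
def count_pairs_of_pairs(arr, x):
    n = len(arr)
    # Each entry is ((i, j), arr[i] + arr[j]) with i < j, in row order.
    pairs = [((i, j), arr[i] + arr[j]) for i in range(n - 1) for j in range(i + 1, n)]
    ans = 0
    for (i0, i1), s0 in pairs:
        for (k, _), s1 in pairs:
            # Valid partner: it starts to the right of i1, so all four
            # indexes are distinct and in increasing order.
            if k > i1 and s0 + s1 == x:
                ans += 1
    return ans

print(count_pairs_of_pairs([0, 0, 0, 0, 0], 0))  # 5
```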

arr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] # n = 10
pairs = [
  (0,1),(0,2),(0,3),(0,4),(0,5),(0,6),(0,7),(0,8),(0,9),# (0,x) -> 9 pairs -> 10 - 0 - 1
  (1,2),(1,3),(1,4),(1,5),(1,6),(1,7),(1,8),(1,9),# (1,x) -> 8 pairs -> 10 - 1 - 1
  (2,3),(2,4),(2,5),(2,6),(2,7),(2,8),(2,9),# (2,x) -> 7 pairs -> 10 - 2 - 1
  (3,4),(3,5),(3,6),(3,7),(3,8),(3,9),# (3,x) -> 6 pairs -> 10 - 3 - 1
  (4,5),(4,6),(4,7),(4,8),(4,9),# (4,x) -> 5 pairs -> 10 - 4 - 1
  (5,6),(5,7),(5,8),(5,9),# (5,x) -> 4 pairs -> 10 - 5 - 1
  (6,7),(6,8),(6,9),# (6,x) -> 3 pairs -> 10 - 6 - 1
  (7,8),(7,9),# (7,x) -> 2 pairs -> 10 - 7 - 1
  (8,9),# (8,x) -> 1 pair -> 10 - 8 - 1
]

# We need to find the first valid pair; all of the pairs after it are valid too.

first valid pair index for (0, 1) => first (2,x) pair => (2,3) => pairs[9 + 8]
first valid pair index for (0, 2) => first (3,x) pair => (3,4) => pairs[9 + 8 + 7]
first valid pair index for (0, 3) => first (4,x) pair => (4,5) => pairs[9 + 8 + 7 + 6]

# There is a pattern
pairs[9 + 8] => pairs[sum(9 to 1) - sum(7 to 1)]
pairs[9 + 8 + 7] => pairs[sum(9 to 1) - sum(6 to 1)]
pairs[9 + 8 + 7 + 6] => pairs[sum(9 to 1) - sum(5 to 1)]

# That's how we get the start and end indexes for the binary search
start = firstNSum(n - 1) - firstNSum(n - i1 - 2)
end = start + n - (i1 + 1) - 1 # n - (i1 + 1) - 1 is the number of pairs in the (i1+1, x) group
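These index formulas can be sanity-checked against an explicitly built list of index pairs, in the same row-by-row order, for n = 10:

```python
def firstNSum(n):
    return n * (n + 1) // 2

n = 10
pairs = [(i, j) for i in range(n - 1) for j in range(i + 1, n)]

for i1 in range(n - 2):
    start = firstNSum(n - 1) - firstNSum(n - i1 - 2)
    end = start + n - (i1 + 1) - 1
    # [start, end) should be exactly the group of pairs whose first index is i1 + 1.
    assert all(p[0] == i1 + 1 for p in pairs[start:end])
    assert end - start == n - (i1 + 1) - 1

# First valid partner of (0, 1): i1 = 1 gives pairs[45 - 28] = pairs[17].
print(pairs[firstNSum(9) - firstNSum(7)])  # (2, 3)
```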

Now we can solve the problem as follows:

# for pair0 in pairs:
    # binary search for all valid sub-arrays of pairs for pair0

Time complexity: O(n^3 log(n)). For each of the O(n^2) pairs, the binary searches over the remaining sub-arrays cost log(n) + log(n-1) + ... + log(1) = log(n!) = O(n log(n)).

Space complexity: O(n^2)

def firstNSum(n):
    # 1 + 2 + ... + n
    return n * (n + 1) // 2

def binary_search(pairs, x, start, end):
    # Lower bound: first index in [start, end) whose sum is >= x.
    while start < end:
        mid = (start + end) // 2
        if pairs[mid][1] < x:
            start = mid + 1
        else:
            end = mid
    return start


def count_four_pairs_with_sum(arr, n, x):
    ans = 0

    pairs = []

    for i0 in range(n - 1):
        for i1 in range(i0 + 1, n):
            curr_sum = arr[i0] + arr[i1]
            pairs.append([(i0, i1), curr_sum])

    for [(i0, i1), curr_sum] in pairs:

        # First group of valid partners: pairs whose first index is i1 + 1.
        start = firstNSum(n - 1) - firstNSum(n - i1 - 2)
        end = start + n - (i1 + 1) - 1

        while start < len(pairs):
            # Each group is sorted, so count the sums equal to x - curr_sum;
            # searching for target + 1 gives the upper bound (integer sums only).
            x_start = binary_search(pairs, x - curr_sum, start, end)
            x_end = binary_search(pairs, x - curr_sum + 1, start, end)

            ans += x_end - x_start

            # Move on to the next group of pairs.
            i1 += 1
            start += n - i1 - 1
            end = start + n - (i1 + 1) - 1

    return ans



arr = [10, 20, 30, 40]
n = len(arr)
x = 100
print(count_four_pairs_with_sum(arr, n, x))

We can do better if, along with the number of pairs for each sum, we also store how many of those pairs come from each (i, x) group of pairs:

# loop for i0
    # loop for i1
        # ans += number of pairs with the needed sum whose first index is greater than i1

Time complexity: O(n^3)

Space complexity: O(n^3)

from collections import defaultdict

def count_four_pairs_with_sum(arr, n, x):
    ans = 0

    # sum_freq[s][i] = number of pairs (i, j), j > i, with arr[i] + arr[j] == s
    sum_freq = defaultdict(lambda: defaultdict(int))

    for i0 in range(n - 1):
        for i1 in range(i0 + 1, n):
            curr_sum = arr[i0] + arr[i1]
            sum_freq[curr_sum][i0] += 1

    for i0 in range(n - 1):
        for i1 in range(i0 + 1, n):
            curr_sum = arr[i0] + arr[i1]
            needed_sum = x - curr_sum
            # Only partner pairs whose first index lies right of i1 are valid.
            valid_needed_sum_count = sum([sum_freq[needed_sum][i] for i in range(i1 + 1, n)])
            ans += valid_needed_sum_count

    return ans


arr = [0, 0, 0, 0, 0]
n = len(arr)
x = 0
print(count_four_pairs_with_sum(arr, n, x))