How to skip some iterations when using itertools.product?

Time:11-02

Suppose there are three sorted lists, A, B, and C:

A = [1, 2, 3, 4]
B = [3, 4, 5]
C = [2, 3, 4]

I am using itertools.product to find all combinations (one item from each list) whose sum is less than 10.

With only three lists, I would write nested loops:

A = [1, 2, 3, 4]
B = [3, 4, 5]
C = [2, 3, 4]

for a in A:
    for b in B:
        for c in C:
            if a + b + c < 10:
                print(a, b, c)
            else:
                break  # C is sorted, so every later c fails too

Because every list is sorted, I use break to stop scanning C as soon as a + b + c reaches 10.
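A small sketch makes the saving concrete: with the lists above and a threshold of 10, the nested loops with break examine fewer (a, b, c) triples than filtering every tuple from itertools.product, while finding the same combinations. The two counting helpers are hypothetical, written only for this illustration.

```python
from itertools import product

A = [1, 2, 3, 4]
B = [3, 4, 5]
C = [2, 3, 4]

def nested_with_break(A, B, C, threshold):
    """Nested loops from the question; returns (matches, triples examined)."""
    matches, examined = [], 0
    for a in A:
        for b in B:
            for c in C:
                examined += 1
                if a + b + c < threshold:
                    matches.append((a, b, c))
                else:
                    break  # C is sorted: every later c is at least as large
    return matches, examined

def product_filter(A, B, C, threshold):
    """Plain filter over itertools.product; every triple is examined."""
    matches, examined = [], 0
    for a, b, c in product(A, B, C):
        examined += 1
        if a + b + c < threshold:
            matches.append((a, b, c))
    return matches, examined

m1, n1 = nested_with_break(A, B, C, 10)
m2, n2 = product_filter(A, B, C, 10)
print(m1 == m2, n1, n2)  # True 27 36
```

Both versions find 18 combinations in the same order; the break only saves work on the innermost list.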

But how do I use break with itertools.product? In other words, how can I jump directly to a specific iteration (e.g., a = 3, b = 3, c = 3)?

for a, b, c in itertools.product(A, B, C):
    ...  # how do I skip ahead here?

Answer:

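You can use itertools.dropwhile to discard tuples until (3, 3, 3) appears. Note that this does not skip any work: product still generates every earlier tuple and dropwhile tests each one; they are only hidden from the loop body.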

from itertools import product, dropwhile

A = [1, 2, 3, 4] 
B = [3, 4, 5] 
C = [2, 3, 4]

for a, b, c in dropwhile(lambda x: x != (3, 3, 3), product(A, B, C)):
    print(a, b, c)

It gives:

3 3 3
3 3 4
3 4 2
3 4 3
3 4 4
. . .
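If you really need to start at a given tuple without generating the earlier ones, one option is to assemble the rest of the lexicographic order from smaller products and chain them together. The sketch below uses a hypothetical helper, product_from, which takes the start position as indices into each list; it is not part of itertools.

```python
from itertools import chain, product

def product_from(start, *lists):
    """Hypothetical helper: yield the same tuples as product(*lists), in the
    same order, beginning at the tuple whose index in each list is given by
    `start`, without creating any of the tuples before it."""
    n = len(lists)
    blocks = []
    # Block k keeps positions 0..k-1 fixed at the start tuple, lets position k
    # run from its start index (last position) or just past it (others), and
    # lets every later position range over its whole list.
    for k in range(n - 1, -1, -1):
        first = start[k] if k == n - 1 else start[k] + 1
        fixed = [[lists[i][start[i]]] for i in range(k)]
        blocks.append(product(*fixed, lists[k][first:], *lists[k + 1:]))
    return chain.from_iterable(blocks)

A = [1, 2, 3, 4]
B = [3, 4, 5]
C = [2, 3, 4]

# (3, 3, 3) sits at index 2 of A, index 0 of B and index 1 of C.
start = (A.index(3), B.index(3), C.index(3))
for a, b, c in product_from(start, A, B, C):
    print(a, b, c)
```

For these lists it yields the same 17 tuples as the dropwhile version, but the 19 tuples before (3, 3, 3) are never created.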

Answer:

It isn't possible to skip iterations inside itertools.product. However, because the lists are sorted, you can reduce the number of iterations: use binary search (bisect) to find how many items of B and C are below the remaining budget, so that only candidates that can still satisfy the threshold are visited, and memoize those search results:

import itertools
import bisect


def bisect_fast(A, B, C, threshold):
    # Memoize bisect results: the same remaining budget can recur
    # for different values of a (or of a + b).
    seen_b_diff = {}
    seen_c_diff = {}

    for a in A:
        b_diff = threshold - a
        if b_diff not in seen_b_diff:
            seen_b_diff[b_diff] = bisect.bisect_left(B, b_diff)

        # In B we are only interested in items that are less than `b_diff`
        for i in range(seen_b_diff[b_diff]):
            b = B[i]
            c_diff = threshold - (a + b)
            # In `C` we are only interested in items that are less than `c_diff`
            if c_diff not in seen_c_diff:
                seen_c_diff[c_diff] = bisect.bisect_left(C, c_diff)

            for j in range(seen_c_diff[c_diff]):
                yield a, b, C[j]


def naive(A, B, C, threshold):
    for a, b, c in itertools.product(A, B, C):
        if a + b + c < threshold:
            yield a, b, c
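For reference, bisect.bisect_left(seq, x) on a sorted seq returns the number of items strictly less than x, which is why range(index) in bisect_fast visits exactly the candidates that can still fit under the threshold. A small check with the question's B and C, a threshold of 10, and the illustrative choice a = 3, b = 4:

```python
import bisect

B = [3, 4, 5]
C = [2, 3, 4]
threshold = 10
a = 3

# Items of B strictly below threshold - a = 7: all three fit.
k_b = bisect.bisect_left(B, threshold - a)
print(k_b, B[:k_b])  # 3 [3, 4, 5]

b = 4
# Items of C strictly below threshold - a - b = 3: only 2 fits,
# because 3 + 4 + 3 == 10 is not < 10.
k_c = bisect.bisect_left(C, threshold - a - b)
print(k_c, C[:k_c])  # 1 [2]
```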

Output (from an IPython session, since %timeit is an IPython magic):

>>> from random import choice
>>> A, B, C = [sorted([choice(list(range(1000))) for _ in range(250)]) for _ in range(3)]
>>> list(naive(A, B, C, 1675)) == list(bisect_fast(A, B, C, 1675))
True
>>> %timeit list(bisect_fast(A, B, C, 1675))
1.59 s ± 32.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
>>> %timeit list(naive(A, B, C, 1675))
3.09 s ± 55.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)