Suppose there are three sorted lists, A, B, C.
A = [1, 2, 3, 4]
B = [3, 4, 5]
C = [2, 3, 4]
I am using itertools.product to find all possible combinations whose sum is smaller than 10.
With only these three lists, I would use the following code:
A = [1, 2, 3, 4]
B = [3, 4, 5]
C = [2, 3, 4]
for a in A:
    for b in B:
        for c in C:
            if a + b + c < 10:
                print(a, b, c)
            else:
                break
Every list is sorted, so I use break for efficiency: once a + b + c is no longer below 10, every later c in C can only make the sum larger.
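The same early-exit idea can be generalized to any number of sorted lists with a recursive generator instead of itertools.product. This is a minimal sketch, not code from the question: the names `combos_below`, `min_rest` and `walk` are hypothetical, and it assumes one or more lists, each sorted in ascending order. Unlike the single `break` above, which only leaves the innermost loop, it stops at every level as soon as even the smallest possible completion (the first item of each remaining list) would reach the threshold.

```python
from itertools import product
from typing import Iterator, List, Sequence, Tuple


def combos_below(lists: Sequence[Sequence[int]], threshold: int) -> Iterator[Tuple[int, ...]]:
    """Yield tuples, in itertools.product order, whose sum is < threshold.

    Hypothetical helper. Assumes one or more lists, each sorted ascending.
    """
    if any(len(lst) == 0 for lst in lists):
        return  # the product of an empty list is empty
    # min_rest[d] = smallest possible sum of lists[d:], i.e. the sum of their first items
    min_rest = [0] * (len(lists) + 1)
    for d in range(len(lists) - 1, -1, -1):
        min_rest[d] = min_rest[d + 1] + lists[d][0]

    def walk(depth: int, prefix: List[int], total: int) -> Iterator[Tuple[int, ...]]:
        if depth == len(lists):
            yield tuple(prefix)
            return
        for x in lists[depth]:
            # Even the cheapest completion reaches the threshold, and later x are
            # no smaller (sorted list), so no later x at this level can work either.
            if total + x + min_rest[depth + 1] >= threshold:
                break
            prefix.append(x)
            yield from walk(depth + 1, prefix, total + x)
            prefix.pop()

    yield from walk(0, [], 0)


A, B, C = [1, 2, 3, 4], [3, 4, 5], [2, 3, 4]
# Same tuples, in the same order, as filtering every tuple of itertools.product
assert list(combos_below([A, B, C], 10)) == [t for t in product(A, B, C) if sum(t) < 10]
```

Because the bound uses the first item of each remaining list rather than assuming the items are non-negative, the pruning stays correct even when the lists contain negative numbers.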
But when I use itertools.product, how can I use break? In other words, how can I jump directly to a specific iteration (e.g., a = 3, b = 3, c = 3)?
for a, b, c in itertools.product(A, B, C):
    ....?
Answer:
You can try the following:
from itertools import product, dropwhile
A = [1, 2, 3, 4]
B = [3, 4, 5]
C = [2, 3, 4]
for a, b, c in dropwhile(lambda x: x != (3, 3, 3), product(A, B, C)):
    print(a, b, c)
It gives:
3 3 3
3 3 4
3 4 2
3 4 3
3 4 4
. . .
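Keep in mind that `dropwhile` does not jump ahead: it still pulls every earlier tuple out of `product` and tests it, and after the first match it yields everything that follows, including tuples such as (3, 4, 4) whose sum is not below 10. The short check below counts how many tuples the predicate sees before (3, 3, 3) is returned; the predicate name `before_start` and the `checked` list are illustrative only.

```python
from itertools import dropwhile, product

A = [1, 2, 3, 4]
B = [3, 4, 5]
C = [2, 3, 4]

checked = []  # every tuple dropwhile's predicate was asked about


def before_start(t):
    checked.append(t)
    return t != (3, 3, 3)


it = dropwhile(before_start, product(A, B, C))
first = next(it)
print(first)         # (3, 3, 3)
print(len(checked))  # 20: the 19 skipped tuples plus (3, 3, 3) itself
```

So `dropwhile` saves typing, not iterations; the sum test is still needed inside the loop.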
Answer:
It isn't possible to skip iterations in itertools.product, but since the lists are sorted, you can reduce the number of iterations by using binary search to find only the items that are lower than the remaining difference, and by memoizing the search results:
import itertools
import bisect

def bisect_fast(A, B, C, threshold):
    seen_b_diff = {}
    seen_c_diff = {}
    for a in A:
        b_diff = threshold - a
        if b_diff not in seen_b_diff:
            seen_b_diff[b_diff] = bisect.bisect_left(B, b_diff)
        # In `B` we are only interested in items that are less than `b_diff`
        for b_ind in range(seen_b_diff[b_diff]):
            b = B[b_ind]
            c_diff = threshold - (a + b)
            # In `C` we are only interested in items that are less than `c_diff`
            if c_diff not in seen_c_diff:
                seen_c_diff[c_diff] = bisect.bisect_left(C, c_diff)
            for c_ind in range(seen_c_diff[c_diff]):
                yield a, b, C[c_ind]

def naive(A, B, C, threshold):
    for a, b, c in itertools.product(A, B, C):
        if a + b + c < threshold:
            yield a, b, c
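The memoized indices in `bisect_fast` rely on `bisect.bisect_left` returning the number of items strictly less than the searched value, so `range(index)` covers exactly the candidates that keep the sum below the threshold. The trace below replays one step with the question's lists, using a = 4 and b = 3 (values chosen for illustration), and shows why `bisect_left`, not `bisect_right`, matches the strict `<` test when there are duplicates.

```python
import bisect

B = [3, 4, 5]
C = [2, 3, 4]
threshold = 10

# bisect_left(seq, x) is the insertion point of x, which is also the count of
# items strictly less than x, so seq[:index] holds exactly the items < x.
a = 4
b_diff = threshold - a                   # 6: b must be < 6
b_index = bisect.bisect_left(B, b_diff)  # 3, so B[:3] == [3, 4, 5]
b = B[0]
c_diff = threshold - (a + b)             # 3: c must be < 3
c_index = bisect.bisect_left(C, c_diff)  # 1, so C[:1] == [2]
print(B[:b_index], C[:c_index])          # [3, 4, 5] [2]

# With duplicates, bisect_left excludes items equal to x (strict <),
# while bisect_right would include them (<=).
D = [2, 3, 3, 4]
print(bisect.bisect_left(D, 3), bisect.bisect_right(D, 3))  # 1 3
```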
Output
>>> from random import choice
>>> A, B, C = [sorted([choice(list(range(1000))) for _ in range(250)]) for _ in range(3)]
>>> list(naive(A, B, C, 1675)) == list(bisect_fast(A, B, C, 1675))
True
>>> %timeit list(bisect_fast(A, B, C, 1675))
1.59 s ± 32.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
>>> %timeit list(naive(A, B, C, 1675))
3.09 s ± 55.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)