Combinations with restrictions in Python

Time:10-25

I have a list

S = ["A", "B", "C", "D", "E", "F"]

and I want all 3-element combinations without repetition, but with restrictions: some elements cannot appear in the same combination. For example, say that A, B, and C cannot be taken together: if A is in the combination, then neither B nor C may be in it. In that case [A, B, D] is not a valid combination, but [A, D, E] is.
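
To make the rule concrete, here is a minimal sketch of the validity test described above. The helper name `is_valid` is only illustrative (it is not part of the original code); it rejects a combination as soon as it shares two or more elements with any restricted group.

```python
def is_valid(comb, restricts):
    """Return True if comb holds at most one element from each restricted group."""
    comb = set(comb)
    return all(len(comb & set(group)) <= 1 for group in restricts)


# Hypothetical single restriction: A, B and C cannot be taken together.
restricts = [["A", "B", "C"]]

print(is_valid(["A", "B", "D"], restricts))  # False: A and B appear together
print(is_valid(["A", "D", "E"], restricts))  # True
```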

I'm trying to code a general algorithm (the list can be longer and there could be more restrictions).

What I've done so far is

from itertools import combinations

S = ["A", "B", "C", "D", "E", "F"]
restricts = [
    ["A", "B", "C"],
    ["E", "F"]
]

COMBS = []

combs = list(combinations(S, 3))
# for each combination
for comb in combs:
    comb = list(comb)
    print("==> CHECKING", comb)
    valid = True
    # for each restriction
    for restrict in restricts:
        if not valid:
            break
        intersect = len(set(comb).intersection(set(restrict)))
        print("intersect", comb, restrict, "=", intersect)
        # if more than an element
        if intersect > 1:
            valid = False
    print("valid:", valid)
    if valid:
        COMBS.append(comb)
            
print("\nValid combinations:")
print(COMBS)

And it's working

Valid combinations:
[['A', 'D', 'E'], ['A', 'D', 'F'], ['B', 'D', 'E'], ['B', 'D', 'F'], ['C', 'D', 'E'], ['C', 'D', 'F']]

But I'm wondering if there's a better/faster way to do it.

CodePudding user response:

How about this? It does pretty much the same as your code, just more compactly (and without the comments).

for comb in combinations(S, 3):
    if all(len(set(comb).intersection(r)) < 2 for r in restricts):
        COMBS.append(list(comb))

CodePudding user response:

If all you want is shorter/nicer code, you can use this list comprehension:

combs = [
    comb for comb in combinations(S, 3)
    if all(len(set(comb).intersection(r)) < 2 for r in restricts)
]

You could also represent the restrictions differently, mapping each element to the set of elements it cannot appear with:

from collections import defaultdict

new_restricts = defaultdict(set)
for restrict in restricts:
    r = set(restrict)
    for x in restrict:
        r.remove(x)
        new_restricts[x].update(r)
        r.add(x)

combs = [
    comb for comb in combinations(S, 3)
    # new_restricts[x] excludes x itself, so any overlap at all is a conflict
    if all(new_restricts[x].isdisjoint(comb) for x in comb)
]

This way, each combination needs at most 3 set checks (one per element), no matter how many restrictions there are.
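
If speed matters for longer lists, another option (not from the answers above, just a sketch) is to avoid generating invalid combinations in the first place: build the combinations element by element and skip any element that conflicts with one already chosen. The names `build_conflicts` and `restricted_combinations` are made up for this example, and the code assumes the items in the list are distinct.

```python
from collections import defaultdict


def build_conflicts(restricts):
    """Map each element to the set of elements it cannot appear with."""
    conflicts = defaultdict(set)
    for group in restricts:
        for x in group:
            conflicts[x].update(y for y in group if y != x)
    return conflicts


def restricted_combinations(items, k, restricts):
    """Yield k-element combinations of items, pruning conflicting branches early."""
    conflicts = build_conflicts(restricts)

    def extend(start, chosen, banned):
        if len(chosen) == k:
            yield list(chosen)
            return
        for i in range(start, len(items)):
            x = items[i]
            if x in banned:
                continue  # x conflicts with an element already chosen
            chosen.append(x)
            # Everything that conflicts with x is banned further down this branch.
            yield from extend(i + 1, chosen, banned | conflicts.get(x, set()))
            chosen.pop()

    yield from extend(0, [], frozenset())


S = ["A", "B", "C", "D", "E", "F"]
restricts = [["A", "B", "C"], ["E", "F"]]
result = list(restricted_combinations(S, 3, restricts))
print(result)
```

For the example data this yields the same six combinations as the filtering versions. The potential gain is that a partial combination such as [A, B] is dropped immediately instead of being extended and rejected later; whether that pays off in practice depends on how restrictive the groups are.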
