I have a list of 10,000 random sets of different lengths:
```python
import random
random.seed(99)
lst = [set(random.sample(range(1, 10000), random.randint(1, 1000))) for _ in range(10000)]
```
I want to know the fastest way to check if there is any set that is a subset of another set (or equivalently if there is any set that is a superset of another set). Right now I am using the following very basic code:
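```python
def any_containment(lst):
    checked_sets = []
    for st in lst:
        # Note: this only tests whether st is a subset of an earlier set;
        # an earlier set that is a subset of st is not detected.
        if any(st.issubset(s) for s in checked_sets):
            return True
        else:
            checked_sets.append(st)
    return False
```

```python
%timeit any_containment(lst)
# 12.3 ms ± 230 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```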
Clearly, my code does not make use of information from earlier iterations when checking containment. Can anyone suggest the fastest way to do this?
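For a correctness baseline, here is a minimal brute-force sketch (the name `any_containment_bruteforce` is hypothetical, not from the question) that compares every unordered pair in both directions, so its answer does not depend on the order of the list. It needs O(n²) set comparisons, so it is only meant for checking faster versions on small inputs.

```python
from itertools import combinations


def any_containment_bruteforce(sets):
    """Return True if some set in `sets` is a subset of another one.

    Every unordered pair is tested in both directions, so the result does
    not depend on the order of the input list. O(n^2) comparisons: use it
    only as a reference on small inputs.
    """
    for a, b in combinations(sets, 2):
        if a <= b or b <= a:
            return True
    return False


# Small hand-made cases
print(any_containment_bruteforce([{1}, {1, 2}]))     # {1} <= {1, 2}
print(any_containment_bruteforce([{1, 2}, {1}]))     # same pair, reversed order
print(any_containment_bruteforce([{1, 2}, {2, 3}]))  # neither contains the other
```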
Answer:
It seems to be faster to sort the sets by length and then try the small sets as candidate subsets first. Times in ms from ten runs, with data generated as in the question but without seeding:
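| agree | yours (ms) | mine (ms) | ratio (yours/mine) | result |
|-------|-----------:|----------:|-------------------:|--------|
| True  |       2.24 |      2.98 |               0.75 | True   |
| True  |     146.25 |      3.10 |              47.19 | True   |
| True  |     121.66 |      2.90 |              41.91 | True   |
| True  |       0.21 |      2.73 |               0.08 | True   |
| True  |      37.01 |      2.82 |              13.10 | True   |
| True  |       5.86 |      3.13 |               1.87 | True   |
| True  |      54.61 |      3.14 |              17.40 | True   |
| True  |       0.86 |      2.81 |               0.30 | True   |
| True  |     182.51 |      3.06 |              59.60 | True   |
| True  |     192.93 |      2.73 |              70.65 | True   |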
Code:
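```python
import random
from timeit import default_timer as time

def original(lst):
    # The question's code: checks whether each set is a subset of an earlier one.
    checked_sets = []
    for st in lst:
        if any(st.issubset(s) for s in checked_sets):
            return True
        else:
            checked_sets.append(st)
    return False

def any_containment(lst):
    # Sort largest first so that pop() returns the smallest remaining set.
    remaining = sorted(lst, key=len, reverse=True)
    while remaining:
        s = remaining.pop()
        # A subset cannot be longer than its superset, so only the sets still
        # in `remaining` (all at least as long as s) need to be checked.
        if any(s <= t for t in remaining):
            return True
    return False

for _ in range(10):
    lst = [set(random.sample(range(1, 10000), random.randint(1, 1000))) for _ in range(10000)]
    t0 = time()
    expect = original(lst)
    t1 = time()
    result = any_containment(lst)
    t2 = time()
    te = t1 - t0  # time taken by original
    tr = t2 - t1  # time taken by any_containment
    # Columns: agree, original (ms), new (ms), ratio, result
    print(result == expect, '%6.2f ' * 3 % (te*1e3, tr*1e3, te/tr), expect)
```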
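A quick sanity check on hand-made inputs (a sketch; the three test lists are made up for illustration) shows where the two functions can differ. `original` only asks whether each new set is a subset of an earlier one, while the sorted version compares every set against all sets at least as long as it. With the random data above both agree, but on a two-element list where the superset comes second, `original` returns False.

```python
def original(lst):
    checked_sets = []
    for st in lst:
        # Only tests "st is a subset of an earlier set".
        if any(st.issubset(s) for s in checked_sets):
            return True
        checked_sets.append(st)
    return False


def any_containment(lst):
    # Largest first, so pop() yields the smallest remaining set.
    remaining = sorted(lst, key=len, reverse=True)
    while remaining:
        s = remaining.pop()
        # Every set still in `remaining` is at least as long as s.
        if any(s <= t for t in remaining):
            return True
    return False


cases = [
    [{1}, {1, 2}],     # subset first, superset second: original misses this
    [{1, 2}, {1}],     # superset first, subset second
    [{1, 2}, {2, 3}],  # no containment at all
]
for case in cases:
    print(case, original(case), any_containment(case))
```

For equal-length sets, `s <= t` holds only when `s == t`, so checking one direction is enough there.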
Edit: The following seems to be roughly another 20% faster:
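```python
def any_containment(lst):
    sets = sorted(lst, key=len)  # shortest first
    # For each i, zip pairs sets[j] with sets[j + len(sets) - i], j = 0 .. i-1.
    # Over all i this compares every pair once, with the shorter (or equal)
    # set on the left, starting with the pairs furthest apart in length order.
    for i in range(1, len(sets)):
        for s, t in zip(sets, sets[-i:]):
            if s <= t:
                return True
    return False
```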
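To see which pairs the `zip` version compares, here is a small sketch (the helper name `visited_pairs` is made up for illustration) that records the index pairs in the order the two loops produce them. With `n` sets sorted by length, `zip(sets, sets[-i:])` pairs index `j` with index `j + n - i`, so the outer loop walks the index gap from `n - 1` down to `1` and covers every pair `(j, k)` with `j < k` exactly once.

```python
def visited_pairs(n):
    """Index pairs (j, k) compared by the zip-based loop for n sorted sets."""
    idx = list(range(n))
    pairs = []
    for i in range(1, n):
        # zip stops at the shorter argument, so this yields i pairs
        # (j, j + n - i) for j = 0 .. i - 1.
        for j, k in zip(idx, idx[-i:]):
            pairs.append((j, k))
    return pairs


print(visited_pairs(4))
# [(0, 3), (0, 2), (1, 3), (0, 1), (1, 2), (2, 3)]
```

The first comparisons pair the shortest sets with the longest ones, which for random data like this are presumably the pairs most likely to be in a containment relation, and each inner loop is driven by `zip` rather than by Python-level indexing.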
Comparison with my old solution (times in ms):
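| agree | old (ms) | new (ms) | ratio (old/new) | result |
|-------|---------:|---------:|----------------:|--------|
| True  |     3.13 |     2.46 |            1.27 | True   |
| True  |     3.36 |     3.31 |            1.02 | True   |
| True  |     3.10 |     2.49 |            1.24 | True   |
| True  |     2.72 |     2.43 |            1.12 | True   |
| True  |     2.86 |     2.35 |            1.21 | True   |
| True  |     2.65 |     2.47 |            1.07 | True   |
| True  |     5.24 |     4.29 |            1.22 | True   |
| True  |     3.01 |     2.35 |            1.28 | True   |
| True  |     2.72 |     2.28 |            1.19 | True   |
| True  |     2.80 |     2.45 |            1.14 | True   |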