Consider all arrays of l non-negative integers in the range 0,...,m. I would like to iterate (using a generator) over only those whose sum is exactly s.
For example, take l=7, s=5, m=4; the iteration could look like:
(0, 0, 0, 0, 0, 1, 4)
(0, 0, 0, 0, 0, 2, 3)
(0, 0, 0, 0, 0, 3, 2)
(0, 0, 0, 0, 0, 4, 1)
(0, 0, 0, 0, 1, 0, 4)
(0, 0, 0, 0, 1, 1, 3)
(0, 0, 0, 0, 1, 2, 2)
(0, 0, 0, 0, 1, 3, 1)
(0, 0, 0, 0, 1, 4, 0)
[...]
(3, 2, 0, 0, 0, 0, 0)
(4, 0, 0, 0, 0, 0, 1)
(4, 0, 0, 0, 0, 1, 0)
(4, 0, 0, 0, 1, 0, 0)
(4, 0, 0, 1, 0, 0, 0)
(4, 0, 1, 0, 0, 0, 0)
(4, 1, 0, 0, 0, 0, 0)
I don't mind which order the iteration happens in, but I would like it to be efficient.
A really dumb way to reproduce the above that is far too slow for larger values of the variables is:
import itertools

s = 5
l = 7
m = 4

# Tries all (m + 1) ** l arrays with entries in 0..m and keeps those summing to s
for arr in itertools.product(range(m + 1), repeat=l):
    if sum(arr) == s:
        print(arr)
Answer:
By modifying this answer to take max_value into account:
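def sums(length, total_sum, max_value):
    """Yield every tuple of `length` ints in [0, max_value] that sums to total_sum."""
    if length == 1:
        # Recursive calls already keep total_sum within bounds via the range
        # below; this guard only matters for a top-level call with length 1.
        if 0 <= total_sum <= max_value:
            yield (total_sum,)
    else:
        # Lower bound: the other length - 1 slots can absorb at most
        # (length - 1) * max_value. Upper bound: value can't exceed either limit.
        for value in range(max(0, total_sum - (length - 1) * max_value),
                           min(max_value, total_sum) + 1):
            for permutation in sums(length - 1, total_sum - value, max_value):
                yield (value,) + permutation

L = list(sums(7, 5, 4))
print('total permutations:', len(L))
# First and last 10 of list
for i in L[:10] + L[-10:]:
    print(i)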
total permutations: 455
(0, 0, 0, 0, 0, 1, 4)
(0, 0, 0, 0, 0, 2, 3)
(0, 0, 0, 0, 0, 3, 2)
(0, 0, 0, 0, 0, 4, 1)
(0, 0, 0, 0, 1, 0, 4)
(0, 0, 0, 0, 1, 1, 3)
(0, 0, 0, 0, 1, 2, 2)
(0, 0, 0, 0, 1, 3, 1)
(0, 0, 0, 0, 1, 4, 0)
(0, 0, 0, 0, 2, 0, 3)
(3, 1, 0, 0, 1, 0, 0)
(3, 1, 0, 1, 0, 0, 0)
(3, 1, 1, 0, 0, 0, 0)
(3, 2, 0, 0, 0, 0, 0)
(4, 0, 0, 0, 0, 0, 1)
(4, 0, 0, 0, 0, 1, 0)
(4, 0, 0, 0, 1, 0, 0)
(4, 0, 0, 1, 0, 0, 0)
(4, 0, 1, 0, 0, 0, 0)
(4, 1, 0, 0, 0, 0, 0)
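The total of 455 can be checked without enumerating anything. Below is a minimal sketch using the standard inclusion-exclusion (stars and bars) count for bounded compositions: count all ways to split the total into `length` non-negative parts, then subtract the arrangements where some part exceeds max_value. The function name `count_bounded` is hypothetical, not from any library.

```python
from math import comb


def count_bounded(length: int, total: int, max_value: int) -> int:
    """Number of tuples of `length` ints in [0, max_value] that sum to `total`.

    Inclusion-exclusion over the k parts forced to exceed max_value:
        sum_k (-1)^k * C(length, k) * C(total - k*(max_value+1) + length - 1, length - 1)
    """
    if length == 0:
        return 1 if total == 0 else 0
    result = 0
    for k in range(length + 1):
        # Give k chosen parts (max_value + 1) up front; distribute the rest freely.
        remaining = total - k * (max_value + 1)
        if remaining < 0:
            break
        result += (-1) ** k * comb(length, k) * comb(remaining + length - 1, length - 1)
    return result


print(count_bounded(7, 5, 4))    # 455
print(count_bounded(14, 10, 8))  # 1143870
```

The second call reproduces the 1,143,870 figure quoted for l=14, s=10, m=8, which makes it a cheap sanity check for any of the generators here.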
Answer:
Think of the problem this way: you want to put s balls in l buckets with no more than m balls in any one bucket.
Since I know how to add one ball at a time, my instinct is to solve this using recursion. The base case is putting 0 balls instead of s, and to go from one step to the next, add 1 ball to each of the buckets that currently have fewer than m balls in them.
To make sure it's actually possible to complete the recursion, we first check that there are enough places to put the balls.
def balls_in_buckets(num_balls, num_buckets, max_balls):
    assert num_buckets * max_balls >= num_balls, f"You can't put {num_balls} balls in {num_buckets} buckets without more than {max_balls} in a bucket."
    if num_balls == 0:
        # Base case: every bucket is empty
        yield [0] * num_buckets
    else:
        # Arrangements already yielded at this level; the same array can be
        # reached by adding the balls in different orders
        seen = set()
        for array in balls_in_buckets(num_balls - 1, num_buckets, max_balls):
            for bucket_number in range(num_buckets):
                if array[bucket_number] < max_balls:
                    # Add one ball to this bucket without mutating `array`
                    array_copy = array.copy()
                    array_copy[bucket_number] += 1
                    if tuple(array_copy) not in seen:
                        seen.add(tuple(array_copy))
                        yield array_copy
Edit: Added code to remove duplicates
Note: it takes about 7 seconds to generate the whole sequence for l=14, s=10, m=8. There are 1,143,870 items in the sequence. Maybe there is a faster way to generate them that avoids duplicates.
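One way to keep the balls-in-buckets picture but never produce a duplicate is to stop caring about the order in which balls are added: each arrangement corresponds to exactly one multiset of bucket indices (which bucket each ball landed in), and `itertools.combinations_with_replacement` yields every such multiset exactly once. A minimal sketch of that swapped-in technique, with the hypothetical name `balls_in_buckets_unique`; I have not timed it against the version above:

```python
from collections import Counter
from collections.abc import Iterator
from itertools import combinations_with_replacement


def balls_in_buckets_unique(num_balls: int, num_buckets: int,
                            max_balls: int) -> Iterator[tuple[int, ...]]:
    """Yield each arrangement of balls exactly once, without a `seen` set.

    Each sorted tuple of bucket indices from combinations_with_replacement
    is one distinct multiset, i.e. one distinct arrangement.
    """
    for choice in combinations_with_replacement(range(num_buckets), num_balls):
        counts = Counter(choice)  # bucket index -> number of balls in it
        # Skip arrangements that overfill a bucket.
        if all(c <= max_balls for c in counts.values()):
            # Counter returns 0 for buckets that received no balls.
            yield tuple(counts[b] for b in range(num_buckets))


for arr in balls_in_buckets_unique(5, 7, 4):
    print(arr)
```

Memory stays constant because nothing is remembered between items. The trade-off is that it does not prune: it visits C(l + s - 1, s) multisets and discards the overfilled ones, so it wastes work when m is much smaller than s.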
Answer:
What you are looking for are called "partitions". Unfortunately, there's some ambiguity as to whether "partitions" refers to splitting a set into partitions (e.g. [a,b,c] into [[a,b],[c]]), just the numbers characterizing the size of each split (e.g. [2,1]), or the count of how many different splittings there are. The most promising search terms I found were "partition a number into k parts python", yielding Python Integer Partitioning with given k partitions, and "python partition of indistinguishable items", yielding Partition N items into K bins in Python lazily. Those answers focus on partitions where every part has at least one element, while you allow parts to be zero. In addition, you seem to care about the order within a partition. For instance, you list (0, 0, 0, 0, 0, 1, 4), (0, 0, 0, 0, 0, 4, 1), and (0, 0, 0, 0, 1, 0, 4) as distinct partitions, while traditionally those would be considered equivalent. In combinatorics, ordered splittings like these that allow zero parts are usually called weak compositions.
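To make the distinction concrete, here is a small sketch that maps each array to its traditional partition by dropping the zeros and ignoring order. The helper name `as_partition` is hypothetical:

```python
# The three arrays listed in the question, as ordered tuples.
examples = [(0, 0, 0, 0, 0, 1, 4), (0, 0, 0, 0, 0, 4, 1), (0, 0, 0, 0, 1, 0, 4)]


def as_partition(arr):
    """Forget order and zeros: keep the non-zero parts, largest first."""
    return tuple(sorted((x for x in arr if x > 0), reverse=True))


# As ordered arrays they are three distinct items...
print(len(set(examples)))                   # 3
# ...but all three correspond to the same integer partition of 5.
print({as_partition(a) for a in examples})  # {(4, 1)}
```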
I think the best way is to iterate through the buckets, and for each one iterate through the possible values. I changed the parameter names: l, s, and m are not very informative names, and "sum" and "max" are built-in functions. My current version, which may need more debugging:
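def get_partitions(length, total, upper_bound):
    # One bucket left: it must take the whole remaining total, if allowed
    if length == 1:
        if total > upper_bound:
            return []
        return [[total]]
    # Nothing left to distribute: every remaining bucket is 0
    if total == 0:
        return [[0] * length]
    # Try each allowed value for the first bucket, then recurse on the rest
    return [[n] + sub_partition
            for n in range(min(total, upper_bound) + 1)
            for sub_partition in get_partitions(length - 1, total - n, upper_bound)]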
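get_partitions builds the whole list eagerly, whereas the question asks for a generator. A minimal sketch of a lazy variant with the same bucket-by-bucket structure, assuming tuples are acceptable output; the name `iter_partitions` is hypothetical. It also prunes any branch whose remaining total can't fit in the remaining buckets:

```python
from collections.abc import Iterator


def iter_partitions(length: int, total: int,
                    upper_bound: int) -> Iterator[tuple[int, ...]]:
    """Lazily yield every tuple of `length` ints in [0, upper_bound] summing to `total`."""
    # Prune: the remaining buckets can hold at most length * upper_bound.
    if total < 0 or total > length * upper_bound:
        return
    if length == 0:
        # Reached only when total == 0, thanks to the check above.
        yield ()
        return
    # Try each value for the first bucket, then recurse on the rest.
    for first in range(min(total, upper_bound) + 1):
        for rest in iter_partitions(length - 1, total - first, upper_bound):
            yield (first,) + rest


gen = iter_partitions(7, 5, 4)
print(next(gen))  # (0, 0, 0, 0, 0, 1, 4)
```

Because it is a generator, callers can stop early (for example with `itertools.islice`) without paying for the whole sequence.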
Side note: I initially read "iterate over the arrays" as meaning "go through the elements of the array". I think that the proper terminology is "iterate over the set of such arrays". When you say "iterate over x", x is being treated as the iterable, not the elements of the iterable.