Python recursive generator breaks when using list() and append()

Time:11-21

I have only recently learned about coroutines using generators and tried to implement the concept in the following recursive function:

def _recursive_nWay_generator(input: list, output={}):
    '''
    Helper function; used to generate parameter-value pairs
    to submit to the model for the simulation.

    Parameters
    ----------
    input : list of tuple
        every tuple of the list must be of the form:
        ``('name_of_parameter', iterable_of_values)``

    output : dict, optional
        parameter used for recursion; accumulates the
        parameter-value pairs across subgenerators

    Yields
    ------
    dict :
        Specifications used for simulation setup of the form:
        ``{'par1': val1, ...}``
    '''
    # exit condition
    if len(input) == 0:
        yield output
    # recursive loop
    else:
        curr = input[0]
        par_name = curr[0]
        for par_value in curr[1]:
            output[par_name] = par_value
            # coroutines for the win!
            yield from _recursive_nWay_generator(input[1:], output=output)

The function appears to work as intended:

testlist = [('a', (1, 2, 3)), ('b', (4, 5, 6)), ('c', (7, 8))]
for a in _recursive_nWay_generator(testlist):
    print(a)

Output:

{'a': 1, 'b': 4, 'c': 7}
{'a': 1, 'b': 4, 'c': 8}
{'a': 1, 'b': 5, 'c': 7}
{'a': 1, 'b': 5, 'c': 8}
{'a': 1, 'b': 6, 'c': 7}
{'a': 1, 'b': 6, 'c': 8}
{'a': 2, 'b': 4, 'c': 7}
{'a': 2, 'b': 4, 'c': 8}
{'a': 2, 'b': 5, 'c': 7}
{'a': 2, 'b': 5, 'c': 8}
{'a': 2, 'b': 6, 'c': 7}
{'a': 2, 'b': 6, 'c': 8}
{'a': 3, 'b': 4, 'c': 7}
{'a': 3, 'b': 4, 'c': 8}
{'a': 3, 'b': 5, 'c': 7}
{'a': 3, 'b': 5, 'c': 8}
{'a': 3, 'b': 6, 'c': 7}
{'a': 3, 'b': 6, 'c': 8}

However, it breaks when I collect the results, whether by appending them to an existing list or by constructing a new one with list():

gen = _recursive_nWay_generator(testlist)
print(list(gen))

Output:

[{'a': 3, 'b': 6, 'c': 8}, {'a': 3, 'b': 6, 'c': 8}, {'a': 3, 'b': 6, 'c': 8}, {'a': 3, 'b': 6, 'c': 8}, {'a': 3, 'b': 6, 'c': 8}, {'a': 3, 'b': 6, 'c': 8}, {'a': 3, 'b': 6, 'c': 8}, {'a': 3, 'b': 6, 'c': 8}, {'a': 3, 'b': 6, 'c': 8}, {'a': 3, 'b': 6, 'c': 8}, {'a': 3, 'b': 6, 'c': 8}, {'a': 3, 'b': 6, 'c': 8}, {'a': 3, 'b': 6, 'c': 8}, {'a': 3, 'b': 6, 'c': 8}, {'a': 3, 'b': 6, 'c': 8}, {'a': 3, 'b': 6, 'c': 8}, {'a': 3, 'b': 6, 'c': 8}, {'a': 3, 'b': 6, 'c': 8}]

This question was attempting to do something close to what I have, but I'm not seeing answers that could help.

I am honestly clueless as to how to solve this; my online searches turned up nothing, no matter how I phrased the question. If this has been answered before, I'll be happy to just follow the link.

Answer:

The problem with your code is that it reuses the same mutable output dict across the iteration and the recursive calls. You yield output and later modify it with output[par_name] = par_value, but it's the same dict every time, so you're modifying the object that was already yielded. Printing inside the for loop hides this, because each result is printed before the next mutation happens. If you collect the results into a list and print them all at the end, you'll see that they're identical: the same object was yielded each time, and it holds whatever values were assigned last.
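The aliasing can be checked directly by comparing object identities. The sketch below reproduces the question's generator under a hypothetical name, _buggy_generator, with the parameter renamed to params (also an assumption, so it doesn't shadow the built-in input); the logic is otherwise the same.

```python
def _buggy_generator(params, output={}):
    # Same logic as the question's generator: one dict is mutated in place
    if not params:
        yield output
    else:
        name, values = params[0]
        for value in values:
            output[name] = value
            yield from _buggy_generator(params[1:], output=output)

testlist = [('a', (1, 2, 3)), ('b', (4, 5, 6)), ('c', (7, 8))]
results = list(_buggy_generator(testlist))

print(len(results))                    # 18
# Every entry in the list is the very same dict object
print(len({id(d) for d in results}))   # 1
print(results[0] is results[-1])       # True
# ...so every entry shows the values assigned last
print(results[0])                      # {'a': 3, 'b': 6, 'c': 8}
```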

The simplest way to "fix" your existing code is to yield copies, i.e. change the line:

yield output

into this:

yield dict(output)

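However, this algorithm is not great, and I recommend looking for something better. Recursion is a poor choice here: what you are generating is the Cartesian product of the parameter values, which the standard library already provides. Here is a simple, direct way to generate the same sequence more efficiently: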

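import itertools as it

testlist = [('a', (1, 2, 3)), ('b', (4, 5, 6)), ('c', (7, 8))]
# split [(name, values), ...] into ('a', 'b', 'c') and ((1, 2, 3), (4, 5, 6), (7, 8))
keys, vals = zip(*testlist)
# product() yields every combination of one value per parameter
for p in it.product(*vals):
    # pair each parameter name with its value; a new dict is built each time
    print(dict(zip(keys, p)))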
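If you still want a generator you can pass to list() or consume lazily, the same idea can be wrapped in a function. This is a sketch; nway_combinations is a hypothetical name. Because itertools.product advances the rightmost iterable fastest, the order matches the recursive version's output, and each dict is built fresh, so collecting the results into a list is safe. Splitting names and values with list comprehensions rather than zip(*params) also means an empty parameter list yields a single empty dict instead of raising on unpacking.

```python
import itertools as it

def nway_combinations(params):
    # params: list of (name, iterable_of_values) tuples
    keys = [name for name, _ in params]
    vals = [values for _, values in params]
    # product() with no arguments yields one empty tuple, so [] -> [{}]
    for combo in it.product(*vals):
        yield dict(zip(keys, combo))

testlist = [('a', (1, 2, 3)), ('b', (4, 5, 6)), ('c', (7, 8))]
combos = list(nway_combinations(testlist))
print(len(combos))   # 18
print(combos[0])     # {'a': 1, 'b': 4, 'c': 7}
print(combos[-1])    # {'a': 3, 'b': 6, 'c': 8}
```

Note that product() consumes each input iterable completely before yielding anything, which is fine for small tuples of parameter values like these.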