List comprehension to combine lists with walrus operator

Time:12-06

Can I write this code snippet as a list comprehension? I've tried to simplify my code as much as possible. I'm using the walrus operator and I want to combine the lists.

def odd_generator(num):
    if num % 2:
        return [[num], [-num]]
def test():
    result = []
    for p in range(5):
        if res := odd_generator(p):
            result += res
    return result

print(test())

Output:

[[1], [-1], [3], [-3]]

Note: on further inspection, it seems I could just return an empty list in my original code and forget about the walrus operator, incorporating Marat's old answer:

def odd_generator(num):
    if num % 2:
        return [[num], [-num]]
    return []
def test():
    return sum((odd_generator(p) for p in range(5)), start=[])
print(test())

Some benchmarks to see how well each approach scales:

from itertools import chain
from timeit import timeit
def odd_generator_1(num):
    if num % 2: return list(range(-num, num + 1))
def odd_generator_2(num):
    if num % 2: return list(range(-num, num + 1))
    return []
def test_0(num): return [x for i in range(num) for x in (odd_generator_1(i) or ())]
def test_1(num): return [x for lst in map(odd_generator_1, range(num)) if lst for x in lst]
def test_2(num): return sum((res for p in range(num) if (res := odd_generator_1(p))), start=[])
def test_3(num): return sum((odd_generator_1(p) or [] for p in range(num)), start=[])

def test_4(num): return list(chain.from_iterable(map(odd_generator_2, range(num))))
def test_5(num): return [x for i in range(num) for x in odd_generator_2(i)]
def test_6(num): return [x for lst in map(odd_generator_2, range(num)) for x in lst]
def test_7(num): return sum((odd_generator_2(p) for p in range(num)), start=[])
num = 1_000
count = 10
print("#0:", timeit("test_0(num)",globals=globals(), number=count))
print("#1:", timeit("test_1(num)",globals=globals(), number=count))
print("#2:", timeit("test_2(num)",globals=globals(), number=count))
print("#3:", timeit("test_3(num)",globals=globals(), number=count))
print("#=========================================================")
print("#4:", timeit("test_4(num)",globals=globals(), number=count))
print("#5:", timeit("test_5(num)",globals=globals(), number=count))
print("#6:", timeit("test_6(num)",globals=globals(), number=count))
print("#7:", timeit("test_7(num)",globals=globals(), number=count))
#0: 0.36987589998170733
#1: 0.4165434999158606
#2: 7.289839400094934
#3: 13.751488599926233
#=========================================================
#4: 0.2722389999544248
#5: 0.3475140000227839
#6: 0.4873567000031471
#7: 13.470793100073934

Answer:

You can do this with a list comprehension if you really want to. One implementation would be:

def test():
    return [x for lst in map(odd_generator, range(5)) if lst for x in lst]

or if you want to avoid map (and are okay with some ugliness replacing the Nones with empty iterables):

def test():
    return [x for i in range(5) for x in (odd_generator(i) or ())]

Neither case requires the walrus, but I won't argue either of them is particularly pretty.
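For comparison, the walrus can also live inside the comprehension itself: the if clause binds res and filters out None, and a second for clause iterates over it. A minimal sketch, assuming Python 3.8+ and the question's odd_generator (which returns None for even numbers); the name test_walrus is just for illustration:

```python
def odd_generator(num):
    # From the question: implicitly returns None for even num
    if num % 2:
        return [[num], [-num]]

def test_walrus():
    # The if clause skips None and binds res; the inner for flattens it
    return [x for p in range(5) if (res := odd_generator(p)) for x in res]

print(test_walrus())  # [[1], [-1], [3], [-3]]
```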

It would definitely make both solutions nicer if odd_generator always returned a (possibly empty) iterable, simplifying the code to one of:

def test():
    return [x for lst in map(odd_generator, range(5)) for x in lst]

def test():
    return [x for i in range(5) for x in odd_generator(i)]
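If changing odd_generator is an option, turning it into an actual generator function (as its name suggests) gives the possibly empty iterable for free: an even number simply yields nothing. This is a hypothetical rewrite, not the question's code, and test_lazy is an illustrative name:

```python
from itertools import chain

def odd_generator(num):
    # Generator function: yields two one-element lists for odd num, nothing otherwise
    if num % 2:
        yield [num]
        yield [-num]

def test():
    return [x for i in range(5) for x in odd_generator(i)]

def test_lazy():
    # Returns an iterator; callers can wrap it in list() if they need one
    return chain.from_iterable(map(odd_generator, range(5)))

print(test())             # [[1], [-1], [3], [-3]]
print(list(test_lazy()))  # [[1], [-1], [3], [-3]]
```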

or with chain.from_iterable, which pushes all the work to the C layer:

from itertools import chain

def test():
    return list(chain.from_iterable(map(odd_generator, range(5))))

There's an even faster solution using functools.reduce with operator.iconcat, which gets the logical effect of sum without the inefficiencies. However, it only applies when you're definitely making a new list, and it's more obscure; in practice, it's best to stick to chain and not preemptively convert to a list (the caller can always listify if they want to).
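The functools.reduce plus operator.iconcat approach might look like the following sketch. operator.iconcat(a, b) performs a += b and returns a, so each step extends one accumulator list in place. Passing a fresh [] as the initial value keeps it from mutating the first list returned by odd_generator (here assumed to always return a list, as suggested above):

```python
from functools import reduce
import operator

def odd_generator(num):
    # Always returns a list (empty for even num)
    if num % 2:
        return [[num], [-num]]
    return []

def test():
    # The fresh [] initializer is the only list that gets mutated
    return reduce(operator.iconcat, map(odd_generator, range(5)), [])

print(test())  # [[1], [-1], [3], [-3]]
```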


As a side note: don't use sum to combine lists. It's a form of Schlemiel the Painter's Algorithm (it performs repeated not-in-place concatenation, which makes the work O(n²), whereas flattening one level of a nested sequence with in-place concatenation is O(n)). For a simple case, consider:

def foo(x):
    return list(range(x))

Now benchmarking with IPython's %timeit magic (on CPython x86-64 3.10.5):

>>> %timeit sum(map(foo, range(10)), [])
2.28 µs ± 27.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

>>> %timeit list(chain.from_iterable(map(foo, range(10))))  # Tiny bit slower for small inputs
2.54 µs ± 13.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

>>> %timeit sum(map(foo, range(100)), [])  # Larger input, but still fairly moderate size, takes >100x longer
255 µs ± 2.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

>>> %timeit list(chain.from_iterable(map(foo, range(100))))  # Same input processed efficiently takes less than 25x longer
61.8 µs ± 319 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
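To see where the extra cost comes from, here is a rough model (an illustration, not a measurement) that counts how many elements get copied. sum(lists, []) computes acc = acc + lst at every step, building a new list that copies the whole accumulator, whereas list.extend copies only the incoming elements. The helper names are hypothetical:

```python
def copies_with_concat(lengths):
    """Elements copied by repeated acc = acc + lst (what sum(lists, []) does)."""
    copies = acc_len = 0
    for n in lengths:
        acc_len += n
        copies += acc_len  # the new list holds acc_len elements, all copied
    return copies

def copies_with_extend(lengths):
    """Elements copied by acc.extend(lst): each element is copied exactly once."""
    return sum(lengths)

# Same shape as the benchmark: foo(x) returns x elements, for x in range(n)
for n in (10, 100):
    lengths = list(range(n))
    print(n, copies_with_concat(lengths), copies_with_extend(lengths))
# 10 165 45
# 100 166650 4950
```

Concatenation copies roughly 3.7x as many elements as extend at n = 10, but roughly 34x at n = 100, which is the same direction as the timings above: the gap widens as the input grows.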

Answer:

It is possible (and arguably better) to do without the walrus operator. Assuming res is produced by a function f(x) that can return None or an empty result, we can combine short-circuiting (f(p) or []) with a sum of lists, like this:

sum((f(p) or [] for p in range(5)), start=[])

With the walrus operator, it would be:

sum((res for p in range(5) if (res := f(p))), start=[])
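Given the concern about sum on lists raised in the other answer, both expressions above can feed itertools.chain.from_iterable instead, keeping the or [] or walrus filtering unchanged. A sketch, with a hypothetical f standing in for any function that may return None or an empty list:

```python
from itertools import chain

def f(x):
    # Hypothetical stand-in: returns None for even x, a list for odd x
    if x % 2:
        return [[x], [-x]]

# Short-circuit version: None becomes an empty list
flat_or = list(chain.from_iterable(f(p) or [] for p in range(5)))

# Walrus version: falsy results are filtered out before flattening
flat_walrus = list(chain.from_iterable(res for p in range(5) if (res := f(p))))

print(flat_or)      # [[1], [-1], [3], [-3]]
print(flat_walrus)  # [[1], [-1], [3], [-3]]
```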

Old answer:

Generate twice as many numbers, halve them (so that there are equal pairs), and negate every other one:

[[-(i//2) if i&1 else i//2] for i in range(2*n)]

Alternative: sum of lists

sum(([[p], [-p]] for p in range(n)), start=[])
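The sum-of-lists alternative can also be written as a nested comprehension, which avoids the repeated concatenation. A sketch, with n = 3 chosen arbitrarily for illustration; it checks the result against both versions above:

```python
n = 3  # arbitrary example size

# One [value] entry per sign, for each p
pairs = [[sign * p] for p in range(n) for sign in (1, -1)]

# Same result as the sum-of-lists and the halve-and-negate versions
assert pairs == sum(([[p], [-p]] for p in range(n)), start=[])
assert pairs == [[-(i // 2) if i & 1 else i // 2] for i in range(2 * n)]
print(pairs)  # [[0], [0], [1], [-1], [2], [-2]]
```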