>>> n = 3
>>> x = range(n ** 2)
>>> xn = list(zip(*[iter(x)] * n))
In PEP 618, the author gives this example of how zip can be used to chunk data into equal-sized groups. How does it work?
I think that it relies on an implementation detail of zip: taking the first element from each item of the list [iter(x)] * n yields the first n elements of x, because the state of iter(x) changes each time an element is taken.
I think so because the following code replicates the above behavior:
n = 3
x = range(n ** 2)
xn = [iter(x)] * n  # n references to the same iterator
res = []
while True:
    try:
        col = []
        for element in xn:
            col.append(next(element))
        res.append(col)
    except StopIteration:  # the shared iterator is exhausted
        break
However, I would like to make sure that this is indeed the case and that this is a reliable behavior that can be used to chunk elements of an iterable.
CodePudding user response:
It's not an implementation detail of zip. It's how iterators work in Python: they always "consume" and move forward.
For example:
whatever = iter([1, 2, 3])
next(whatever)
# 1
next(whatever)
# 2
What zip does is "advance" each object it's given. In your example, [iter(x)] * n holds n references to one iterator, so the call is basically zip(whatever, whatever, whatever).
Since zip works in sequence, it takes the first next from whatever, then the next from whatever, which has already moved on from the first next, so it gets the value 2. That means the next one is 3, and so on.
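It's behaviour by design and the language guarantees it: the zip documentation states that the left-to-right evaluation order of the iterables is guaranteed.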
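To see why sharing matters, here is a minimal sketch (the variable names are only illustrative) that checks that [iter(x)] * n really holds n references to one iterator object, and contrasts it with n independent iterators:

```python
n = 3
x = range(n ** 2)

# List multiplication copies the reference, not the iterator itself.
iterators = [iter(x)] * n
same_object = all(it is iterators[0] for it in iterators)

# zip pulls from its arguments left to right, so each tuple takes
# n consecutive items from the one shared iterator.
shared = list(zip(*iterators))

# For contrast: n independent iterators each start from the beginning.
independent = list(zip(*[iter(x) for _ in range(n)]))

print(same_object)   # True
print(shared)        # [(0, 1, 2), (3, 4, 5), (6, 7, 8)]
print(independent[:2])  # [(0, 0, 0), (1, 1, 1)]
```

With independent iterators nothing is chunked; the grouping comes entirely from every argument being the same, stateful iterator.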
CodePudding user response:
It's not really specific to zip, but you basically have that right. In effect, it's zipping 3 references to the same iterator, causing it to round-robin between them. Each time zip pulls a value, one more element is consumed from the shared iterator.
Effectively, it's the same as doing this:
>>> n = 3
>>> x = range(n ** 2)
>>> a = b = c = iter(x)
>>> list(zip(a, b, c))
[(0, 1, 2), (3, 4, 5), (6, 7, 8)]
Note that it only produces equal-sized groups and may drop elements (that part is a characteristic of zip, because it's limited by the shortest iterable, though you could use itertools.zip_longest if you want):
>>> n = 4
>>> x = range(n ** 2)
>>> a = b = c = iter(x)
>>> list(zip(a, b, c))
[(0, 1, 2), (3, 4, 5), (6, 7, 8), (9, 10, 11), (12, 13, 14)]