How do I implement custom iterators so that I can nest them?

Time:10-05

I was just looking up some stuff about Python iterators and stumbled across this iterator example from W3Schools:

class MyNumbers:
  def __iter__(self):
    self.a = 1
    return self

  def __next__(self):
    if self.a <= 20:
      x = self.a
      self.a += 1
      return x
    else:
      raise StopIteration

myclass = MyNumbers()
myiter = iter(myclass)

for x in myiter:
  print(x)

The code prints the numbers from 1 to 20 to the console.
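For reference, a for loop over an object is roughly equivalent to calling iter() on it once and then calling next() on the result until StopIteration is raised. The sketch below spells that out by hand for the MyNumbers class above (with the counter advanced via self.a += 1). The names it and collected are only illustrative.

```python
class MyNumbers:
    def __iter__(self):
        # Reset the counter and hand back the object itself as the iterator.
        self.a = 1
        return self

    def __next__(self):
        if self.a <= 20:
            x = self.a
            self.a += 1
            return x
        raise StopIteration


# Roughly what "for x in MyNumbers(): ..." does behind the scenes.
collected = []
it = iter(MyNumbers())   # calls MyNumbers.__iter__
while True:
    try:
        x = next(it)     # calls MyNumbers.__next__
    except StopIteration:
        break            # the loop ends once StopIteration is raised
    collected.append(x)

print(collected)
```

This also makes the problem visible: the only state is self.a on the MyNumbers object, so any second loop over the same object works on that same counter.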

I wondered whether the code works with nested loops, since it stores the number of iterations already performed in an attribute on the object itself. So I set up a small example (with only three iterations instead of 20), and indeed, it does not work as expected:

class MyNumbers:
  def __iter__(self):
    self.a = 1
    return self

  def __next__(self):
    if self.a <= 3:
      x = self.a
      self.a += 1
      return x
    else:
      raise StopIteration

myclass = MyNumbers()
myiter = iter(myclass)

for x in myiter:
  for y in myiter:
    print('outer value: ' + str(x))
    print('inner value: ' + str(y))

print("*"*50)

for x in myclass:
  for y in myclass:
    print('outer value: ' + str(x))
    print('inner value: ' + str(y))

print("*"*50)

for x in iter(myclass):
  for y in iter(myclass):
    print('outer value: ' + str(x))
    print('inner value: ' + str(y))

Output:

outer value: 1
inner value: 1
outer value: 1
inner value: 2
outer value: 1
inner value: 3
**************************************************
outer value: 1
inner value: 1
outer value: 1
inner value: 2
outer value: 1
inner value: 3
**************************************************
outer value: 1
inner value: 1
outer value: 1
inner value: 2
outer value: 1
inner value: 3

I can see how these results occur: all three variants share the single counter self.a stored on the one MyNumbers object. The inner loop calls iter() on that same object, which resets the counter and then advances it past 3, so when the outer loop asks for its next value, the self.a <= 3 check fails immediately. I then tried a similar example with lists, and they behaved differently:

a = [1, 2, 3]
for x in a:
  for y in a:
    print('outer value: ' + str(x))
    print('inner value: ' + str(y))

Output:

outer value: 1
inner value: 1
outer value: 1
inner value: 2
outer value: 1
inner value: 3
outer value: 2
inner value: 1
outer value: 2
inner value: 2
outer value: 2
inner value: 3
outer value: 3
inner value: 1
outer value: 3
inner value: 2
outer value: 3
inner value: 3

This version works as one would expect from nested loops. My question now is: how could I rewrite the given example so that it works as intended? I thought about a factory that generates iterable objects, but that seems really complicated (and I'm not sure it would work either). Does anybody know an easier way?

Answer:

A quick and dirty example to show how this might be achieved:

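class MyList:

    def __init__(self, ls):
        self.ls = ls

    def __iter__(self):
        # A new iterator object is created on every call,
        # so each loop keeps its own position.
        class MyListIter:
            def __init__(self, ls):
                self.ls = ls.copy()  # snapshot of the list
                self.n = -1          # index of the last returned item

            def __iter__(self):
                # Iterators should return themselves from __iter__.
                return self

            def __next__(self):
                self.n += 1
                if self.n >= len(self.ls):
                    raise StopIteration
                return self.ls[self.n]

        return MyListIter(self.ls)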


x = MyList([1, 2, 4, 8])

for i in x:
    for j in x:
        print(i, j)

Output:

1 1
1 2
1 4
1 8
2 1
2 2
2 4
2 8
4 1
4 2
4 4
4 8
8 1
8 2
8 4
8 8

The trick is to keep track of each loop's position separately, so I've added a second object, MyListIter, to take care of this: every call to MyList.__iter__ returns a new MyListIter with its own counter.
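Applied to the question's own MyNumbers class, the same idea means moving the counter out of the iterable and into a separate iterator class, so that MyNumbers.__iter__ returns a fresh iterator on every call. This is a minimal sketch; the class name MyNumbersIterator and the limit parameter are illustrative choices, not part of the original code.

```python
class MyNumbersIterator:
    """Holds the position of one traversal; each loop gets its own instance."""

    def __init__(self, limit):
        self.limit = limit
        self.a = 1

    def __iter__(self):
        # Iterators should return themselves from __iter__.
        return self

    def __next__(self):
        if self.a > self.limit:
            raise StopIteration
        x = self.a
        self.a += 1
        return x


class MyNumbers:
    """Iterable: hands out a new iterator every time iter() is called."""

    def __init__(self, limit=3):
        self.limit = limit

    def __iter__(self):
        return MyNumbersIterator(self.limit)


nums = MyNumbers(3)
pairs = []
for x in nums:
    for y in nums:
        print('outer value: ' + str(x))
        print('inner value: ' + str(y))
        pairs.append((x, y))
```

Because iter(nums) now returns a different object each time, the inner loop's progress no longer touches the outer loop's counter, which mirrors how iter() on a list returns a new list iterator on each call.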

There are a few other ways this can be done as well, but this is probably the simplest pattern.
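One such alternative (not necessarily the one the answer had in mind) is to write __iter__ as a generator function. Each call to a generator function returns a new generator object with its own local variables, so nested loops stay independent without a separate iterator class. A minimal sketch, with a hypothetical limit parameter:

```python
class MyNumbers:
    def __init__(self, limit=3):
        self.limit = limit

    def __iter__(self):
        # Because this method contains "yield", each call returns a new
        # generator whose local variable "a" is independent of any other.
        a = 1
        while a <= self.limit:
            yield a
            a += 1


nums = MyNumbers(3)
pairs = []
for x in nums:
    for y in nums:
        pairs.append((x, y))
print(pairs)
```

The trade-off is mostly stylistic: the generator keeps the iteration logic in one place, while the explicit iterator class makes the per-loop state visible as attributes.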
