A function that returns a batch of data every time it's called

Time:05-04

I am trying to create a function that returns a batch of data (list) every time I call it.

It should be able to repeat for any number of training steps and restart from the beginning after having iterated over the whole dataset (after each epoch).

def generate_batch(X, batch_size):
    for i in range(0, len(X), batch_size):
        batch = X[i:i + batch_size]
        yield batch

X = [
    [1, 2],
    [4, 0],
    [5, 1],
    [9, 99],
    [9, 1],
    [1, 1],
]

num_training_steps = 3

for step in range(num_training_steps):
    x_batch = generate_batch(X, batch_size=2)
    print(list(x_batch))

When I print the output, I see that every step returns the whole dataset (X) split into batches, not a single batch:

[[[1, 2], [4, 0]], [[5, 1], [9, 99]], [[9, 1], [1, 1]]]
[[[1, 2], [4, 0]], [[5, 1], [9, 99]], [[9, 1], [1, 1]]]
[[[1, 2], [4, 0]], [[5, 1], [9, 99]], [[9, 1], [1, 1]]]

What is the problem? Is this the right way to use yield?

Answer:

First of all, if you want the generator to restart from the beginning once it has gone through all the data, you need to wrap the generator function body in an infinite loop, like this:

def generate_batch(X, batch_size):
    # restart from the first batch after every full pass (epoch)
    while True:
        for i in range(0, len(X), batch_size):
            batch = X[i:i + batch_size]
            yield batch
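A quick way to confirm the wrap-around without running an endless loop is to take a fixed number of batches with itertools.islice. The sketch below repeats the infinite generator (with the slice written as X[i:i + batch_size]) and uses the sample X from the question; islice is only a convenient way to stop after five batches, and first_five is an illustrative name.

```python
from itertools import islice

def generate_batch(X, batch_size):
    # Infinite generator: after the last batch of an epoch it starts over at index 0.
    while True:
        for i in range(0, len(X), batch_size):
            yield X[i:i + batch_size]

X = [[1, 2], [4, 0], [5, 1], [9, 99], [9, 1], [1, 1]]

# Take 5 batches from ONE generator: 3 batches form one epoch, then it wraps around.
first_five = list(islice(generate_batch(X, batch_size=2), 5))
for batch in first_five:
    print(batch)
```

With six rows and batch_size=2, the fourth batch is [[1, 2], [4, 0]] again, i.e. the start of the second epoch.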

Then, when you do:

x_batch = generate_batch(X, batch_size=2)

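Now x_batch is a generator. You need to iterate over it or call next() on it to get the data one batch at a time. If you just do list(x_batch), it iterates over the generator and collects all the batches into a list (with the infinite loop above, it would never finish). This is not what you want. Your original loop also created a new generator on every step, so each step started again from the first batch.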

What you want is:

gen = generate_batch(X, batch_size=2)

for step in range(num_training_steps):
    x_batch = next(gen)
    print(x_batch)
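To see why the question's loop printed the whole dataset, the two patterns can be compared side by side. This sketch deliberately uses the single-pass generator from the question (slice written as X[i:i + batch_size]) so that list() terminates; drained, first and second are illustrative names.

```python
def generate_batch(X, batch_size):
    # One pass (one epoch) over X, one slice per next() call.
    for i in range(0, len(X), batch_size):
        yield X[i:i + batch_size]

X = [[1, 2], [4, 0], [5, 1], [9, 99], [9, 1], [1, 1]]

# Pattern from the question: a fresh generator, drained completely by list().
drained = list(generate_batch(X, batch_size=2))
print(drained)  # all three batches at once

# Pattern from this answer: create the generator once, then call next() per step.
gen = generate_batch(X, batch_size=2)
first = next(gen)
second = next(gen)
print(first)   # [[1, 2], [4, 0]]
print(second)  # [[5, 1], [9, 99]]
```

The generator keeps its position between next() calls, which is exactly the state that gets thrown away when a new generator is created on every step.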

Or alternatively, if you need a callable function:

gen = generate_batch(X, batch_size=2)
gen = gen.__next__

for step in range(num_training_steps):
    x_batch = gen()
    print(x_batch)
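If you prefer not to grab the __next__ dunder method directly, binding next() to the generator with functools.partial gives an equivalent zero-argument callable. The sketch assumes the infinite generate_batch from this answer; get_batch is a hypothetical name chosen for illustration.

```python
from functools import partial

def generate_batch(X, batch_size):
    # Infinite generator that restarts after every epoch.
    while True:
        for i in range(0, len(X), batch_size):
            yield X[i:i + batch_size]

X = [[1, 2], [4, 0], [5, 1], [9, 99], [9, 1], [1, 1]]

# partial(next, gen) is a callable; each get_batch() call is equivalent to next(gen).
get_batch = partial(next, generate_batch(X, batch_size=2))

batches = [get_batch() for _ in range(4)]
for batch in batches:
    print(batch)
```

Both forms keep a reference to the same generator, so each call advances it by one batch.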

Also, you probably want to give the function a different name, e.g. create_batch_generator(), since it returns a generator rather than a batch.

Answer:

You can use itertools.cycle for this. It keeps repeating the list endlessly, similar to what repeat() does for a tf.data dataset.

Here is a small tweak to your code:

from itertools import cycle

def generate_batch(X, batch_size):
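    # cycle() repeats the elements of X endlessly
    dataset = cycle(X)

    while True:
        # zip() stops as soon as range(batch_size) is exhausted, so it pulls
        # exactly batch_size items from the cycled dataset
        batch = list(zip(range(batch_size), dataset))
        # keep only the data items, dropping the counter from range()
        yield list(map(lambda x: x[1], batch))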
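The zip/range/map combination above simply takes the next batch_size items from the cycled iterator. itertools.islice expresses that step directly; the following is an equivalent variant offered as a sketch, not the answer's original code.

```python
from itertools import cycle, islice

def generate_batch(X, batch_size):
    # cycle() repeats X forever; islice() consumes exactly batch_size items from it.
    dataset = cycle(X)
    while True:
        yield list(islice(dataset, batch_size))

X = [[1, 2], [4, 0], [5, 1], [9, 99], [9, 1], [1, 1]]
gen = generate_batch(X, 2)
batches = [next(gen) for _ in range(4)]
for batch in batches:
    print(batch)
```

Because islice() stops after batch_size items, no element of the cycled dataset is skipped between batches.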

That's it. Now you can plug it into your code:

X = [
    [1, 2],
    [4, 0],
    [5, 1],
    [9, 99],
    [9, 1],
    [1, 1],
]


gen = generate_batch(X, 2)

for step in range(20):
    # one next() call per training step; looping over gen directly would never end
    batch = next(gen)
    print(batch)

The output looks like this:

[[1, 2], [4, 0]]
[[5, 1], [9, 99]]
[[9, 1], [1, 1]]
[[1, 2], [4, 0]]
[[5, 1], [9, 99]]
[[9, 1], [1, 1]]
... and so on
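The two answers behave differently when len(X) is not a multiple of batch_size. The sketch below uses a hypothetical five-row list X5 (the first five rows of X) and two illustrative function names: slicing_batches follows the first answer, cycling_batches follows the itertools.cycle answer with its zip/range logic written as a list comprehension.

```python
from itertools import cycle, islice

def slicing_batches(X, batch_size):
    # First answer's approach: each epoch may end with a shorter batch.
    while True:
        for i in range(0, len(X), batch_size):
            yield X[i:i + batch_size]

def cycling_batches(X, batch_size):
    # itertools.cycle approach: batches are always full and can span two epochs.
    dataset = cycle(X)
    while True:
        yield [item for _, item in zip(range(batch_size), dataset)]

X5 = [[1, 2], [4, 0], [5, 1], [9, 99], [9, 1]]  # 5 rows, not a multiple of 2

sliced = list(islice(slicing_batches(X5, 2), 4))
cycled = list(islice(cycling_batches(X5, 2), 4))
print(sliced)
print(cycled)
```

The slicing version yields [[9, 1]] as the last batch of each epoch and then restarts cleanly; the cycling version always yields two rows, so its third batch mixes the end of one epoch with the start of the next. Which behavior is preferable depends on whether your training loop needs exact epoch boundaries.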