I am trying to create a function that returns a batch of data (list) every time I call it.
It should be able to repeat for any number of training steps and restart from the beginning after having iterated over the whole dataset (after each epoch).
def generate_batch(X, batch_size):
    for i in range(0, len(X), batch_size):
        batch = X[i:i + batch_size]
        yield batch
X = [
    [1, 2],
    [4, 0],
    [5, 1],
    [9, 99],
    [9, 1],
    [1, 1]]
for step in range(num_training_steps):
    x_batch = generate_batch(X, batch_size=2)
    print(list(x_batch))
When I print the output of the function, I see that each step gets the whole dataset (X), not a single batch:
[[[1, 2], [4, 0]], [[5, 1], [9, 99]], [[9, 1], [1, 1]]]
[[[1, 2], [4, 0]], [[5, 1], [9, 99]], [[9, 1], [1, 1]]]
[[[1, 2], [4, 0]], [[5, 1], [9, 99]], [[9, 1], [1, 1]]]
What is the problem? Is this the right way to use yield?
CodePudding user response:
First of all, if you want to restart from the beginning after the data is over, you will need to wrap the generator function body in an infinite loop, like this:
def generate_batch(X, batch_size):
    while True:
        for i in range(0, len(X), batch_size):
            batch = X[i:i + batch_size]
            yield batch
Then, when you do:
x_batch = generate_batch(X, batch_size=2)
x_batch is a generator, not a batch. You need to iterate over it or call next() on it to get the data one batch at a time. If you just do list(x_batch), it iterates over the generator and collects all the batches into a list, which is not what you want. With the infinite loop above, list(x_batch) would never return at all.
What you want is:
gen = generate_batch(X, batch_size=2)
for step in range(num_training_steps):
    x_batch = next(gen)
    print(x_batch)
Or alternatively, if you need a callable function:
gen = generate_batch(X, batch_size=2)
get_batch = gen.__next__  # bound method: each call returns the next batch
for step in range(num_training_steps):
    x_batch = get_batch()
    print(x_batch)
Also, you probably want to give the function a name that says it returns a generator, e.g. create_batch_generator().
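As a quick check of the approach above, here is a minimal self-contained sketch. It assumes the toy dataset from the question and uses the suggested name create_batch_generator; the batches list exists only for illustration. It shows that the generator is created once and that the fourth batch wraps around to the start of the data.

```python
def create_batch_generator(X, batch_size):
    """Yield consecutive slices of X forever, restarting after each epoch."""
    while True:
        for i in range(0, len(X), batch_size):
            yield X[i:i + batch_size]


X = [[1, 2], [4, 0], [5, 1], [9, 99], [9, 1], [1, 1]]

# Create the generator once, outside the training loop.
gen = create_batch_generator(X, batch_size=2)

# Pull one batch per training step (four steps here).
batches = [next(gen) for _ in range(4)]

# Six rows at batch size 2 give three batches per epoch,
# so the fourth batch is the first one again.
print(batches[3])  # [[1, 2], [4, 0]]
```

Note that when len(X) is not a multiple of batch_size, the last batch of each epoch is shorter than the others, and an empty X would make next() loop forever without yielding.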
CodePudding user response:
Well, you can use itertools.cycle for this. It keeps repeating the list, much like tf.data.Dataset.repeat() does.
It only takes a small tweak to your source code:
from itertools import cycle
def generate_batch(X, batch_size):
    dataset = cycle(X)
    while True:
        batch = list(zip(range(batch_size), dataset))
        yield list(map(lambda x: x[1], batch))
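The zip(range(batch_size), dataset) idiom relies on range being the first argument, so that zip stops before pulling an extra item from the cycle. A possibly more direct way to take a fixed number of items is itertools.islice. The sketch below is an alternative spelling, not the answer's original code; the name generate_batch_islice is made up for illustration, and a batch size of 4 is used to show what happens at the epoch boundary.

```python
from itertools import cycle, islice


def generate_batch_islice(X, batch_size):
    """Yield batches of exactly batch_size items, cycling over X forever."""
    dataset = cycle(X)
    while True:
        # islice consumes exactly batch_size items from the shared iterator.
        yield list(islice(dataset, batch_size))


X = [[1, 2], [4, 0], [5, 1], [9, 99], [9, 1], [1, 1]]

gen = generate_batch_islice(X, 4)
first = next(gen)
second = next(gen)

print(first)   # [[1, 2], [4, 0], [5, 1], [9, 99]]
print(second)  # [[9, 1], [1, 1], [1, 2], [4, 0]]
```

With a cycle-based generator every batch has the same size, but a batch can straddle two epochs, as the second batch does here. If batches must never mix epochs, slicing the list inside an infinite loop is the simpler choice.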
That's it. Now you can plug it into your code:
X = [
    [1, 2],
    [4, 0],
    [5, 1],
    [9, 99],
    [9, 1],
    [1, 1]]
gen = generate_batch(X, 2)  # create the generator once; it never ends on its own
for step in range(20):
    print(next(gen))  # one batch per training step
It prints the following:
[[1, 2], [4, 0]]
[[5, 1], [9, 99]]
[[9, 1], [1, 1]]
[[1, 2], [4, 0]]
[[5, 1], [9, 99]]
[[9, 1], [1, 1]]
... and so on