Batch Generator function returns only last element when converted to list

Time:11-20

I have the following code to generate batches from a list of items:

def batch_generator(items, batch_size):
    count = 1
    chunk = []
    
    for item in items:
        if count % batch_size:
            chunk.append(item)
        else:
            chunk.append(item)
            yield chunk
            chunk.clear()
        count += 1
    
    if len(chunk):
        yield chunk

Iterating one-by-one yields expected results:

for x in batch_generator(range(17), 5):
    print(x)
# [0, 1, 2, 3, 4]
# [5, 6, 7, 8, 9]
# [10, 11, 12, 13, 14]
# [15, 16]

However, when I convert the generator to a list directly, only the last batch is returned, repeated multiple times!

list(batch_generator(range(17), 5))
# [[15, 16], [15, 16], [15, 16], [15, 16]]

Whereas a simple generator converted to list works just fine:

list(([i,i*2,i*3] for i in range(5)))
# [[0, 0, 0], [1, 2, 3], [2, 4, 6], [3, 6, 9], [4, 8, 12]]

Why is this happening?

CodePudding user response:

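chunk.clear() is the problem here. The generator yields the same list object every time, and clear() empties that object in place, so the resulting list just holds several references to one list, which ends up containing the last batch. Printing inside a for loop looks correct only because each batch is printed before the generator resumes and clears it.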
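
To make the aliasing visible, here is a small sketch. The generator is a condensed copy of the one in the question (the if/else branches are merged, otherwise the behavior is meant to be the same), and the names batches and same_object are only illustrative:

```python
def batch_generator(items, batch_size):
    # Condensed copy of the generator from the question:
    # a single list object is reused for every batch.
    count = 1
    chunk = []
    for item in items:
        chunk.append(item)
        if count % batch_size == 0:
            yield chunk
            chunk.clear()  # empties the same object that was just yielded
        count += 1
    if chunk:
        yield chunk


batches = list(batch_generator(range(17), 5))

# Every entry in `batches` is the very same list object.
same_object = all(b is batches[0] for b in batches)
print(same_object)  # True
print(batches)      # [[15, 16], [15, 16], [15, 16], [15, 16]]
```

The identity check with is shows that list() collected four references to one list rather than four separate lists.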

Replace chunk.clear() with chunk = []. That way each yielded chunk is a different list instance.
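
A minimal sketch of that change, keeping the structure of the generator from the question and only swapping the clear() call for a rebinding:

```python
def batch_generator(items, batch_size):
    count = 1
    chunk = []

    for item in items:
        chunk.append(item)
        if count % batch_size == 0:
            yield chunk
            # Rebind the name to a brand-new list; the list that was
            # just yielded is left untouched for the caller.
            chunk = []
        count += 1

    # Yield whatever is left over as a final, shorter batch.
    if chunk:
        yield chunk


print(list(batch_generator(range(17), 5)))
# [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14], [15, 16]]
```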

CodePudding user response:

You can yield a copy of chunk with chunk[:] or list(chunk) instead of just yield chunk.
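
A sketch of the copying variant, again based on the generator from the question. Here chunk.clear() can stay, because the caller only ever receives copies:

```python
def batch_generator(items, batch_size):
    count = 1
    chunk = []

    for item in items:
        chunk.append(item)
        if count % batch_size == 0:
            # Hand out a shallow copy; the internal list can then be reused.
            yield list(chunk)
            chunk.clear()
        count += 1

    if chunk:
        yield list(chunk)


print(list(batch_generator(range(17), 5)))
# [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14], [15, 16]]
```

Note that list(chunk) and chunk[:] make shallow copies: the batch lists are distinct, but the items inside them are still the same objects as in the input.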
