Concatenate N PyTorch tensors (of the same shape) generated within a loop

Time:06-02

Tensors of the same shape are being returned from within a loop and I want to concatenate them succinctly and as pythonically / pytorchly as possible.

Current solution:

import torch

for object_id in object_ids:
    
    dataset = Dataset(object_id)

    image_tensor = dataset.get_random_image_tensor()

    if 'concatenated_image_tensors' in locals():
        concatenated_image_tensors = torch.cat((concatenated_image_tensors, image_tensor))
    else:
        concatenated_image_tensors = image_tensor

Is there a better way?

CodePudding user response:

A good approach is to append each tensor to a Python list, then concatenate the whole list once at the end. Otherwise, every call to torch.cat allocates a new tensor and copies all the data accumulated so far, so the amount of copying grows with each iteration.

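import torch

# Collect the per-object tensors in a plain Python list first.
image_tensors = []
for object_id in object_ids:
    dataset = Dataset(object_id)
    image_tensors.append(dataset.get_random_image_tensor())

# Concatenate once, along the first dimension (dim=0, the default).
concatenated_image_tensors = torch.cat(image_tensors)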
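
For a version that runs on its own, here is a minimal sketch of the same list-then-concatenate pattern. The question's Dataset class is not shown, so the helper get_random_image_tensor below is a hypothetical stand-in, and the tensor shape (1, 3, 8, 8) is an assumption chosen only for illustration.

```python
import torch


def get_random_image_tensor(object_id: int) -> torch.Tensor:
    # Hypothetical stand-in for Dataset(object_id).get_random_image_tensor().
    # Assumed shape: (1, channels, height, width), seeded by object_id.
    generator = torch.Generator().manual_seed(object_id)
    return torch.rand(1, 3, 8, 8, generator=generator)


object_ids = [0, 1, 2, 3]  # assumed example IDs

# Build the list of tensors, then concatenate a single time.
image_tensors = [get_random_image_tensor(object_id) for object_id in object_ids]
concatenated = torch.cat(image_tensors, dim=0)
print(concatenated.shape)  # torch.Size([4, 3, 8, 8])

# If each tensor lacks a leading batch dimension, torch.stack adds one instead.
stacked = torch.stack([t.squeeze(0) for t in image_tensors], dim=0)
print(stacked.shape)  # torch.Size([4, 3, 8, 8])
```

Note that torch.cat joins tensors along an existing dimension, while torch.stack creates a new one, so which call fits depends on whether the tensors returned in the loop already carry a batch dimension.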