Tensors of the same shape are being returned from within a loop and I want to concatenate them succinctly and as pythonically / pytorchly as possible.
Current solution:
import torch

for object_id in object_ids:
    dataset = Dataset(object_id)
    image_tensor = dataset.get_random_image_tensor()
    if 'concatenated_image_tensors' in locals():
        # Append the new tensor to everything collected so far
        concatenated_image_tensors = torch.cat((concatenated_image_tensors, image_tensor))
    else:
        # First iteration: nothing to concatenate with yet
        concatenated_image_tensors = image_tensor
Is there a better way?
Answer:
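A good approach is to first append each tensor to a Python list, then concatenate the whole list once at the end. Otherwise, every call to torch.cat allocates a new tensor and copies all the data accumulated so far, so the total amount of copying grows quadratically with the number of tensors.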
all_img = []
for object_id in object_ids:
    dataset = Dataset(object_id)
    image_tensor = dataset.get_random_image_tensor()
    all_img.append(image_tensor)
all_img = torch.cat(all_img)
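The following is a minimal, self-contained sketch of the same idea. It assumes a hypothetical stand-in for the question's Dataset class that returns random tensors of shape (3, 4, 4); the object ids are also made up. It collects the tensors with a list comprehension, concatenates once, and checks that the result matches the incremental approach from the question. It also shows torch.stack, which may be what you want if the goal is a batch of images (a new leading dimension) rather than joining along the existing first dimension.

```python
import torch

# Hypothetical stand-in for the question's Dataset class (assumption):
# each "image" is a random tensor of shape (3, 4, 4).
class Dataset:
    def __init__(self, object_id):
        self.object_id = object_id

    def get_random_image_tensor(self):
        return torch.rand(3, 4, 4)

object_ids = range(5)  # hypothetical ids

# Collect every tensor first...
images = [Dataset(object_id).get_random_image_tensor() for object_id in object_ids]

# ...then concatenate once. torch.cat joins along an existing dimension
# (dim=0 by default): five (3, 4, 4) tensors become one (15, 4, 4) tensor.
concatenated = torch.cat(images)

# torch.stack adds a new leading dimension instead: shape (5, 3, 4, 4).
batch = torch.stack(images)

# The incremental version gives the same values, but each torch.cat call
# copies all the data gathered so far into a freshly allocated tensor.
result = images[0]
for image in images[1:]:
    result = torch.cat((result, image))

print(concatenated.shape)  # torch.Size([15, 4, 4])
print(batch.shape)         # torch.Size([5, 3, 4, 4])
```

Both torch.cat and torch.stack require the tensors to have matching shapes in the dimensions being preserved, which holds here because every tensor has the same shape.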