I have a tensor of shape [5, 2, 18, 4096]. I want to stack the 0th dimension along the 2nd dimension


The tensor has shape [5, 2, 18, 4096]. I want to take each slice along the 0th dimension, which has shape [2, 18, 4096], and stack it on top of the other slices of the same tensor (also of shape [2, 18, 4096]), doing this for all slices along the 0th dimension. The final tensor should have shape [2, 90, 4096].

CodePudding user response:

It turns out there's a very simple approach: for tensors with two or more dimensions, torch.hstack stacks along the second dimension (i.e. along axis 1). For instance, consider the following:

import torch

start = torch.arange(120).reshape([5, 4, 3, 2])
result = torch.hstack(list(start))

The tensor start has shape (5, 4, 3, 2), while result has shape (4, 15, 2), which comes from stacking five (4, 3, 2) tensors along axis 1.

Applying list to a multidimensional tensor splits it along the first axis (dim 0). In this case, list(start) is a list of five tensors, each of shape (4, 3, 2).
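Applied to the shape from the question, the same pattern should give the desired result. Here is a minimal sketch, with a random tensor standing in for the real data:

import torch

x = torch.randn(5, 2, 18, 4096)  # placeholder for the real data
stacked = torch.hstack(list(x))  # same as torch.cat(list(x), dim=1)
print(stacked.shape)             # torch.Size([2, 90, 4096])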

CodePudding user response:

I did arrive at a general approach for solving this, but is there a better way to do it? Also, is it mathematically correct?

chunks = torch.Tensor(self.buffer)  # shape is [5, 2, 18, 4096]
chunks = chunks.permute(1, 0, 2, 3)  # shape is now [2, 5, 18, 4096]
chunks = chunks.reshape(chunks.shape[0], chunks.shape[1] * chunks.shape[2], chunks.shape[-1])
# the resulting shape is [2, 90, 4096]
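As a sanity check, this permute-then-reshape route produces exactly the same tensor as the torch.hstack answer above: permuting moves the chunk index next to the dimension being merged, so the reshape lays out the 18-row blocks in the same order as concatenation along dim 1. A quick sketch, with a random tensor standing in for self.buffer:

import torch

buffer = torch.randn(5, 2, 18, 4096)  # stand-in for self.buffer
via_reshape = buffer.permute(1, 0, 2, 3).reshape(2, 5 * 18, 4096)
via_hstack = torch.hstack(list(buffer))
print(torch.equal(via_reshape, via_hstack))  # True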