I have 6 tensors of shape (batch_size, S, S, 1) and I want to combine them into one Python list of shape (batch_size, S*S, 6), so that every element of the tensors ends up inside the inner list.
Can this be achieved without using loops? What's an efficient way to do it?
CodePudding user response:
Let batch_size=10 and S=4 for the purpose of this example:
>>> import torch
>>> x = [torch.rand(10, 4, 4, 1) for _ in range(6)]
Indeed, the first step is to concatenate the tensors on the last dimension, axis=3:
>>> y = torch.cat(x, -1)
>>> y.shape
torch.Size([10, 4, 4, 6])
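As an aside (my addition, not part of the original answer), the same intermediate tensor can also be built by stacking on a new trailing dimension and then dropping the singleton axis 3:
>>> y_alt = torch.stack(x, dim=-1).squeeze(3)  # (10, 4, 4, 1, 6) -> (10, 4, 4, 6)
>>> torch.equal(y_alt, torch.cat(x, -1))
True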
Then reshape to flatten axis=1 and axis=2; you can do so with torch.flatten here since the two axes are adjacent:
>>> y = torch.cat(x, -1).flatten(1, 2)
>>> y.shape
torch.Size([10, 16, 6])
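If you prefer an explicit reshape (my addition, not part of the original answer), the same result can be obtained with .reshape, letting -1 infer the flattened S*S dimension:
>>> y2 = torch.cat(x, -1).reshape(y.size(0), -1, 6)
>>> torch.equal(y, y2)
True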