How to expand the dimensions of a tensor in PyTorch

Time:06-10

I'm new to PyTorch. Suppose I have a 2-dimensional tensor like this:

A = torch.tensor([[1, 2, 3], [4, 5, 6]])

How can I get a 3-dimensional tensor that repeats it, like this:

B = tensor([[[1, 2, 3],
             [4, 5, 6]],

            [[1, 2, 3],
             [4, 5, 6]]])

CodePudding user response:

You can concatenate ...

A
tensor([[[1., 2., 3.],
         [4., 5., 6.]]])
B = torch.cat((A, A))

B
tensor([[[1., 2., 3.],
         [4., 5., 6.]],

        [[1., 2., 3.],
         [4., 5., 6.]]])
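
Note that the A printed in this answer already has a leading dimension of size 1 (shape (1, 2, 3)), while the A in the question is 2-dimensional (shape (2, 3)). A minimal sketch of how the concatenation could be applied to the question's tensor, assuming the goal is a result of shape (2, 2, 3); the names A3, B_cat and B_stack are only illustrative:

```python
import torch

# The 2-D tensor from the question, shape (2, 3).
A = torch.tensor([[1, 2, 3], [4, 5, 6]])

# Option 1: add a leading dimension of size 1, then concatenate along it.
A3 = A.unsqueeze(0)                   # shape (1, 2, 3)
B_cat = torch.cat((A3, A3), dim=0)    # shape (2, 2, 3)

# Option 2: torch.stack creates the new leading dimension itself,
# so the 2-D tensor can be passed in directly.
B_stack = torch.stack((A, A), dim=0)  # shape (2, 2, 3)

print(B_cat.shape, B_stack.shape)
```

Calling torch.cat((A, A)) directly on the 2-dimensional A would instead join along the existing first dimension and give a tensor of shape (4, 3).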

CodePudding user response:

Just use the repeat function, like this:

B = A.repeat(2, 1, 1)
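
For reference, a small self-contained sketch of this approach applied to the tensor from the question. It also shows expand as a possible alternative that is not part of the answer above: repeat copies the data, whereas expand returns a view that shares memory with A. The names B_repeat and B_expand are only illustrative:

```python
import torch

# The 2-D tensor from the question, shape (2, 3).
A = torch.tensor([[1, 2, 3], [4, 5, 6]])

# repeat(2, 1, 1): A is treated as shape (1, 2, 3) and tiled twice along
# the new leading dimension. The data is copied.
B_repeat = A.repeat(2, 1, 1)                  # shape (2, 2, 3)

# Alternative: expand returns a view without copying the data.
# -1 means "keep this dimension's size unchanged".
B_expand = A.unsqueeze(0).expand(2, -1, -1)   # shape (2, 2, 3)

print(B_repeat.shape, B_expand.shape)
print(torch.equal(B_repeat, B_expand))
```

Because the expanded tensor shares memory with A, in-place writes to it are not safe; if the result needs to be modified, repeat (or calling .clone() on the expanded tensor) is probably the safer choice.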