How to generate indices like [0,2,4,1,3,5] without an explicit loop for reorganizing the rows of a tensor

Time:12-28

Suppose I have a tensor like

import torch

a = torch.tensor([[3, 1, 5, 0, 4, 2],
                  [2, 1, 3, 4, 5, 0],
                  [0, 4, 5, 1, 2, 3],
                  [3, 1, 4, 5, 0, 2],
                  [3, 5, 4, 2, 0, 1],
                  [5, 3, 0, 4, 1, 2]])

and I want to reorganize the rows of the tensor by applying the transformation a[c] where

c = torch.tensor([0,2,4,1,3,5])

to get

b = torch.tensor([[3, 1, 5, 0, 4, 2],
                  [0, 4, 5, 1, 2, 3],
                  [3, 5, 4, 2, 0, 1],
                  [2, 1, 3, 4, 5, 0],
                  [3, 1, 4, 5, 0, 2],
                  [5, 3, 0, 4, 1, 2]])

To do this, I want to generate the tensor c so that the transformation works regardless of the size of tensor a and of the step size (which I have set to 2 in this example for simplicity). How can I generate such a tensor for the general case in PyTorch without an explicit for loop?
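One way to state the requested pattern precisely: for n rows and step s (assuming n is divisible by s), position k of c holds s * (k mod n/s) + floor(k / (n/s)). The following is a minimal sketch of that formula using only torch.arange and integer arithmetic. The function name strided_order is hypothetical, not part of PyTorch.

```python
import torch

def strided_order(n: int, step: int) -> torch.Tensor:
    # Hypothetical helper: builds the index pattern without a Python loop.
    if n % step != 0:
        raise ValueError("n must be divisible by step")
    groups = n // step  # number of entries in each output block
    k = torch.arange(n)
    # (k % groups) * step gives 0, s, 2s, ... within each block;
    # k // groups is the block number, i.e. the offset 0, 1, ..., s-1.
    return (k % groups) * step + k // groups

c = strided_order(6, 2)
print(c)  # tensor([0, 2, 4, 1, 3, 5])
```

With c in hand, b = a[c] reorders the rows as in the example above.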

Answer:

I also came up with another solution, which reorganizes the rows of tensor a into tensor b directly, without generating the index tensor c:

step = 2
# (n, m) -> (n/step, step, m) -> (step, n/step, m) -> (n, m); requires n divisible by step
b = a.view(-1, step, a.size(-1)).transpose(0, 1).reshape(-1, a.size(-1))
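The same reshape-and-transpose trick, applied to the row numbers 0..n-1 instead of to a itself, yields the index tensor c explicitly. A minimal sketch using the tensor from the question, checking that both routes agree:

```python
import torch

a = torch.tensor([[3, 1, 5, 0, 4, 2],
                  [2, 1, 3, 4, 5, 0],
                  [0, 4, 5, 1, 2, 3],
                  [3, 1, 4, 5, 0, 2],
                  [3, 5, 4, 2, 0, 1],
                  [5, 3, 0, 4, 1, 2]])
step = 2

# Reorder the rows directly: (6, 6) -> (3, 2, 6) -> (2, 3, 6) -> (6, 6)
b = a.view(-1, step, a.size(-1)).transpose(0, 1).reshape(-1, a.size(-1))

# Applying the same reshape to the row numbers recovers the index tensor c
c = torch.arange(a.size(0)).view(-1, step).t().reshape(-1)
print(c)  # tensor([0, 2, 4, 1, 3, 5])

print(torch.equal(b, a[c]))  # True
```

Note that view only works when the requested shape is compatible with a's memory layout; if a might be non-contiguous (for example, a transposed tensor), using a.reshape in place of a.view avoids that restriction.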

Answer:

Thinking about it a little longer, I came up with the following solution for generating the indices:

step = 2
idx = torch.arange(0,a.size(0),step)
# idx = tensor([0, 2, 4])
idx = idx.repeat(int(a.size(0)/idx.size(0)))
# idx = tensor([0, 2, 4, 0, 2, 4])
incr = torch.arange(0,step)
# incr = tensor([0, 1])
incr = incr.repeat_interleave(int(a.size(0)/incr.size(0)))
# incr = tensor([0, 0, 0, 1, 1, 1])
c = incr + idx
# c = tensor([0, 2, 4, 1, 3, 5])

The tensor c can then be used to obtain b:

b = a[c.long()]
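Putting the steps above together, here is a minimal sketch that generalizes them to any row count n divisible by step. The name repeat_indices is hypothetical. Integer division n // step stands in for int(a.size(0)/...), and the two index tensors are added with +. The check uses 8 rows and a step of 4 to show it is not tied to the 6-row example.

```python
import torch

def repeat_indices(n: int, step: int) -> torch.Tensor:
    # Hypothetical wrapper around the repeat / repeat_interleave approach.
    assert n % step == 0, "number of rows must be divisible by step"
    idx = torch.arange(0, n, step).repeat(step)             # [0, s, 2s, ...] tiled step times
    incr = torch.arange(step).repeat_interleave(n // step)  # each offset repeated n/step times
    return idx + incr

a = torch.arange(8 * 3).view(8, 3)  # 8 rows, 3 columns
c = repeat_indices(8, 4)
print(c)  # tensor([0, 4, 1, 5, 2, 6, 3, 7])

b = a[c]  # arange already returns int64, so c.long() is not strictly needed
# Same result as the view/transpose approach from the other answer
print(torch.equal(b, a.view(-1, 4, 3).transpose(0, 1).reshape(-1, 3)))  # True
```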

Answer:

You can use torch.index_select:

b = torch.index_select(a, 0, c)

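The official documentation explains it clearly: torch.index_select(input, dim, index) returns a new tensor that indexes input along dimension dim using the entries of index, a 1-D IntTensor or LongTensor.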
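A short sketch showing that torch.index_select along dim 0 gives the same result as a[c] for the tensors in the question. Selecting along dim 1 instead applies the same index tensor to columns, which illustrates the role of the dim argument.

```python
import torch

a = torch.tensor([[3, 1, 5, 0, 4, 2],
                  [2, 1, 3, 4, 5, 0],
                  [0, 4, 5, 1, 2, 3],
                  [3, 1, 4, 5, 0, 2],
                  [3, 5, 4, 2, 0, 1],
                  [5, 3, 0, 4, 1, 2]])
c = torch.tensor([0, 2, 4, 1, 3, 5])  # int64 by default, an accepted index dtype

# Select rows (dim=0) in the order given by c
b = torch.index_select(a, 0, c)
print(torch.equal(b, a[c]))  # True

# The same index tensor reorders columns when selecting along dim=1
cols = torch.index_select(a, 1, c)
print(cols[0])  # tensor([3, 5, 4, 1, 0, 2])
```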
