How to invert the permutations in each row of a 2D Pytorch tensor without using for loop?

Time:12-31

Suppose I have a 2D tensor in which each row is a permutation of 0 to n-1, for example:

a = torch.tensor([[4, 3, 5, 0, 2, 1],
                  [5, 0, 3, 4, 2, 1],
                  [3, 1, 0, 2, 4, 5],
                  [5, 0, 4, 3, 2, 1],
                  [2, 4, 0, 1, 5, 3]])

I want to invert all the permutations in tensor a to get a tensor b. For example, after applying this operation to the tensor a above, my desired output is:

>>> b
tensor([[3, 5, 4, 1, 0, 2],
        [1, 5, 4, 2, 3, 0],
        [2, 1, 3, 0, 4, 5],
        [1, 5, 4, 3, 2, 0],
        [2, 3, 0, 5, 1, 4]])
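For reference, the operation being asked for is: for every row r and column i, b[r, a[r, i]] = i. A minimal loop-based sketch of exactly what the question wants to avoid (the function name invert_rows_loop is hypothetical) can serve as a correctness check for any vectorized version:

```python
import torch

def invert_rows_loop(a: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference helper: b[r, a[r, i]] = i, computed with explicit loops."""
    b = torch.empty_like(a)
    n_rows, n_cols = a.shape
    for r in range(n_rows):
        for i in range(n_cols):
            # Position i of row r holds value a[r, i], so the inverse maps that value back to i.
            b[r, int(a[r, i])] = i
    return b

a = torch.tensor([[4, 3, 5, 0, 2, 1],
                  [5, 0, 3, 4, 2, 1],
                  [3, 1, 0, 2, 4, 5],
                  [5, 0, 4, 3, 2, 1],
                  [2, 4, 0, 1, 5, 3]])
b = invert_rows_loop(a)
print(b)
```

Composing a row with its inverse gives the identity, so torch.gather(a, 1, b) should equal torch.arange(n_cols) in every row.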

I searched online and found this answer. How can I generalize its approach to invert all the permutations without using a for loop?
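One common way to invert a single permutation p is index assignment, inv[p] = torch.arange(len(p)). This is assumed here for illustration and may differ from the approach in the linked answer. A sketch of one way to generalize it to every row is to pair each entry of a with its row index through broadcasting, so a single advanced-indexing assignment writes b[r, a[r, i]] = i for all r and i at once:

```python
import torch

a = torch.tensor([[4, 3, 5, 0, 2, 1],
                  [5, 0, 3, 4, 2, 1],
                  [3, 1, 0, 2, 4, 5],
                  [5, 0, 4, 3, 2, 1],
                  [2, 4, 0, 1, 5, 3]])

# Single permutation: inv[p[i]] = i for every i.
p = a[0]
inv = torch.empty_like(p)
inv[p] = torch.arange(p.numel(), device=p.device)

# Every row at once: rows has shape (n_rows, 1) and broadcasts against a (n_rows, n_cols),
# so b[rows, a] = cols assigns b[r, a[r, i]] = i for all r and i.
n_rows, n_cols = a.shape
rows = torch.arange(n_rows, device=a.device).unsqueeze(1)
cols = torch.arange(n_cols, device=a.device).expand(n_rows, -1)
b = torch.empty_like(a)
b[rows, a] = cols
print(inv)
print(b)
```

Because each row is a permutation, no (row, column) target is written twice, so the assignment needs no accumulation or tie-breaking.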

Answer:

It can be done with torch.Tensor.scatter_, which writes each column index j into column a[r, j] of row r:

b = torch.zeros_like(a)
# b[r, a[r, j]] = j for every row r and column j
b.scatter_(1, a, torch.arange(a.size(1), device=a.device).expand(a.size(0), -1))
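A self-contained run of the scatter_ approach on the question's example is sketched below. scatter_ along dim=1 sets b[r][a[r][j]] = src[r][j], and here src[r][j] = j. As an alternative that is not part of the original answer, torch.argsort(a, dim=1) also yields the inverse: since row r holds exactly the values 0 to n-1, the k-th entry of its argsort is the column i with a[r, i] == k, which is b[r, k]. It costs a sort (O(n log n) per row) instead of a single scatter (O(n)), but it is a one-liner.

```python
import torch

a = torch.tensor([[4, 3, 5, 0, 2, 1],
                  [5, 0, 3, 4, 2, 1],
                  [3, 1, 0, 2, 4, 5],
                  [5, 0, 4, 3, 2, 1],
                  [2, 4, 0, 1, 5, 3]])

# scatter_ along dim 1: b[r, a[r, j]] = src[r, j], with src[r, j] = j.
src = torch.arange(a.size(1), device=a.device).expand(a.size(0), -1)
b = torch.zeros_like(a).scatter_(1, a, src)

# Alternative: the k-th smallest entry of each row is k, so argsort returns,
# for each k, the column i where a[r, i] == k, i.e. the inverse permutation.
b_argsort = torch.argsort(a, dim=1)

print(b)
print(torch.equal(b, b_argsort))
```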