Suppose I have a 2D tensor that contains a permutation in each row, for example:
a = torch.tensor([[4, 3, 5, 0, 2, 1],
                  [5, 0, 3, 4, 2, 1],
                  [3, 1, 0, 2, 4, 5],
                  [5, 0, 4, 3, 2, 1],
                  [2, 4, 0, 1, 5, 3]])
I want to invert all the permutations in tensor a to get a tensor b. For example, after applying this operation to the tensor a above, my desired output is:
>>> b
tensor([[3, 5, 4, 1, 0, 2],
        [1, 5, 4, 2, 3, 0],
        [2, 1, 3, 0, 4, 5],
        [1, 5, 4, 3, 2, 0],
        [2, 3, 0, 5, 1, 4]])
I searched online and found this answer. How can I generalize the approach described there to invert all the permutations at once, without using a for loop?
Answer:
This can be done without a loop using torch.Tensor.scatter_:
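b = torch.zeros_like(a)
# Along dim 1, scatter_ writes b[i][a[i][j]] = j, which inverts each row.
# device=a.device keeps the source indices on the same device as a.
b.scatter_(1, a, torch.arange(a.size(1), device=a.device).expand(a.size(0), -1))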
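To make the idea easier to check, here is a minimal, self-contained sketch that wraps the scatter approach in a helper and verifies it on the tensor from the question. The function name invert_permutations is my own (hypothetical, not a PyTorch API), and the sketch assumes a is a 2D int64 tensor whose rows are valid permutations of 0..n-1. It also compares the result with torch.argsort along dim 1, which, as far as I can tell, gives the same result for valid permutations.

```python
import torch


def invert_permutations(a: torch.Tensor) -> torch.Tensor:
    """Invert the permutation stored in each row of a 2D int64 tensor.

    Hypothetical helper for illustration; assumes every row of `a`
    is a valid permutation of 0..n-1.
    """
    # Column indices 0..n-1, broadcast (without copying) to the shape of `a`.
    cols = torch.arange(a.size(1), device=a.device, dtype=a.dtype).expand_as(a)
    # Out-of-place scatter along dim 1: out[i][a[i][j]] = cols[i][j] = j.
    return torch.zeros_like(a).scatter(1, a, cols)


a = torch.tensor([[4, 3, 5, 0, 2, 1],
                  [5, 0, 3, 4, 2, 1],
                  [3, 1, 0, 2, 4, 5],
                  [5, 0, 4, 3, 2, 1],
                  [2, 4, 0, 1, 5, 3]])

b = invert_permutations(a)

# Check 1: composing a permutation with its inverse gives the identity,
# i.e. a[i][b[i][k]] == k for every row i and position k.
identity = torch.arange(a.size(1)).expand_as(a)
assert torch.equal(torch.gather(a, 1, b), identity)

# Check 2: for a valid permutation, argsort along the row is also its inverse.
assert torch.equal(torch.argsort(a, dim=1), b)
```

The argsort version is shorter to write, but it sorts each row, whereas scatter only does a single pass over the elements, so scatter is probably the better choice for long rows.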