Given an n-by-n matrix A, where each row of A is a permutation of [n] (i.e., of 0, ..., n-1 with 0-based indexing), e.g.,
```python
import torch

n = 100
AA = torch.rand(n, n)
A = torch.argsort(AA, dim=1)  # argsort of random values gives a random permutation per row
```
We are also given another n-by-n matrix P, and we want to construct an n-by-n-by-n tensor Q such that
Q[i, j, k] = P[A[i, j], k]
Is there an efficient way to do this in PyTorch? I am aware of torch.gather, but it seems hard to apply directly here.
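To pin down the definition, and to have a baseline to test faster versions against, the formula can be translated literally into loops. This is a minimal sketch: `build_q_loops` is a hypothetical helper name, and a small n is assumed because it performs n³ Python-level operations.

```python
import torch

def build_q_loops(A: torch.Tensor, P: torch.Tensor) -> torch.Tensor:
    """Slow reference: Q[i, j, k] = P[A[i, j], k], computed entry by entry."""
    rows, cols = A.shape
    m = P.shape[1]
    Q = torch.empty(rows, cols, m, dtype=P.dtype)
    for i in range(rows):
        for j in range(cols):
            for k in range(m):
                Q[i, j, k] = P[A[i, j], k]
    return Q

torch.manual_seed(0)
n = 4
A = torch.argsort(torch.rand(n, n), dim=1)  # each row is a permutation of 0..n-1
P = torch.rand(n, n)
Q_ref = build_q_loops(A, P)
print(Q_ref.shape)  # torch.Size([4, 4, 4])
```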
Answer:
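You can index P with A directly:
```python
Q = P[A]
```
Indexing with an integer tensor selects rows of P, so Q has shape (n, n, n).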
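The sketch below checks `P[A]` against the defining formula on a small random example and shows an equivalent formulation that flattens A, uses `index_select` to pick rows of P, and reshapes back. The small n and random inputs are assumptions for the demo.

```python
import torch

torch.manual_seed(0)
n = 5
A = torch.argsort(torch.rand(n, n), dim=1)  # each row is a permutation of 0..n-1
P = torch.rand(n, n)

# Advanced indexing: row A[i, j] of P is placed at Q[i, j, :]
Q = P[A]

# Same result via index_select on the flattened index, then restoring the (n, n) layout
Q_sel = P.index_select(0, A.reshape(-1)).reshape(n, n, P.shape[1])

# Check the defining formula Q[i, j, k] = P[A[i, j], k] for every entry
ok = all(
    Q[i, j, k] == P[A[i, j], k]
    for i in range(n) for j in range(n) for k in range(n)
)
print(Q.shape, ok, torch.equal(Q, Q_sel))  # torch.Size([5, 5, 5]) True True
```

Both forms run as a single vectorized gather of rows, so neither needs a Python loop.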
Answer:
Why not simply use A as an index?
```python
Q = P[A, :]
```
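Since the question mentions `torch.gather`, here is a sketch of how it could be made to work, mostly to show why plain indexing is simpler: gather needs an index with the same number of dimensions as the output, so both P and A are first broadcast to (n, n, n) with `expand`. It also confirms that `P[A, :]` and `P[A]` give the same tensor. The small n and random inputs are assumptions for the demo.

```python
import torch

torch.manual_seed(0)
n = 6
A = torch.argsort(torch.rand(n, n), dim=1)
P = torch.rand(n, n)

# The explicit trailing ':' is optional: P[A, :] and P[A] select the same rows
Q_slice = P[A, :]

# torch.gather version: gather along dim 1 computes out[i, j, k] = src[i, idx[i, j, k], k].
# With src[i, r, k] = P[r, k] and idx[i, j, k] = A[i, j], this is exactly P[A[i, j], k].
m = P.shape[1]
src = P.unsqueeze(0).expand(n, n, m)   # view of P repeated along a new dim 0, no copy
idx = A.unsqueeze(-1).expand(n, n, m)  # view of A repeated along a new last dim, no copy
Q_gather = torch.gather(src, 1, idx)

print(torch.equal(Q_slice, P[A]), torch.equal(Q_gather, P[A]))  # True True
```

The `expand` calls avoid copying the inputs, but the output still has n³ elements either way; the indexing form just expresses the same thing in one step.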