Construct a 3D tensor from a 2D matrix

Time:09-28

Given an n-by-n matrix A in which each row is a permutation of [n] (0-indexed here, i.e. of 0, ..., n-1, as produced by argsort), e.g.,

import torch
n = 100
AA = torch.rand(n, n)
A = torch.argsort(AA, dim=1)

We are also given another n-by-n matrix P, and we want to construct a 3D tensor Q such that

Q[i, j, k] = P[A[i, j], k]
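To pin down the definition, here is a slow reference construction written directly from the formula. It is only a sketch for checking faster methods against: the small n, the seed, and the random P are illustrative assumptions, not part of the question.

```python
import torch

# Small n so the triple Python loop stays fast; shapes follow the question.
n = 5
torch.manual_seed(0)
A = torch.argsort(torch.rand(n, n), dim=1)  # each row is a permutation of 0..n-1
P = torch.rand(n, n)

# Reference (slow) construction, straight from Q[i, j, k] = P[A[i, j], k].
Q_ref = torch.empty(n, n, n, dtype=P.dtype)
for i in range(n):
    for j in range(n):
        for k in range(n):
            Q_ref[i, j, k] = P[A[i, j], k]
```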

Is there an efficient way to do this in PyTorch? I am aware of torch.gather, but it seems hard to apply directly here.
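For what it's worth, torch.gather can be applied if both tensors are first broadcast to 3D. Along dim=1, gather computes out[i, j, k] = inp[i, idx[i, j, k], k], so expanding P over a new leading axis and A over a new trailing axis yields exactly Q. A minimal sketch, assuming a square P as in the question:

```python
import torch

n = 100
A = torch.argsort(torch.rand(n, n), dim=1)
P = torch.rand(n, n)

# gather along dim=1: out[i, j, k] = inp[i, idx[i, j, k], k]
inp = P.unsqueeze(0).expand(n, n, n)   # inp[i, r, k] = P[r, k]  (no copy)
idx = A.unsqueeze(-1).expand(n, n, n)  # idx[i, j, k] = A[i, j]  (no copy)
Q_gather = torch.gather(inp, 1, idx)   # Q_gather[i, j, k] = P[A[i, j], k]
```

This works, but the plain indexing in the answers is shorter and easier to read.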

Answer:

You can index P directly with A:

Q = P[A]
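This works because indexing the first dimension with an integer tensor ("advanced indexing") replaces that dimension with the index tensor's shape, so the result has shape A.shape + P.shape[1:]. The sketch below uses a non-square P purely to make the shapes easy to tell apart; the column count m is an illustrative assumption (the question has m == n).

```python
import torch

n, m = 100, 7  # m is illustrative; in the question m == n
A = torch.argsort(torch.rand(n, n), dim=1)
P = torch.rand(n, m)

# Advanced indexing on dim 0: result shape = A.shape + P.shape[1:]
Q = P[A]
print(Q.shape)  # torch.Size([100, 100, 7])

# Spot-check the definition Q[i, j, k] = P[A[i, j], k]
i, j, k = 3, 5, 2
assert Q[i, j, k] == P[A[i, j], k]
```

Note that Q is materialized in full: it holds n * n * m elements (one million for n = m = 100).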

Answer:

Why not simply use A as an index?

Q = P[A, :]
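The trailing `:` only makes explicit that the second dimension is kept whole, so P[A, :] and P[A] give the same tensor. As a further cross-check, the same result can be built with torch.index_select on the flattened index followed by a reshape. A small sketch comparing the three:

```python
import torch

n = 100
A = torch.argsort(torch.rand(n, n), dim=1)
P = torch.rand(n, n)

Q1 = P[A]      # first answer
Q2 = P[A, :]   # second answer; ':' keeps dim 1 whole
# Gather rows P[A[i, j]] in row-major order of A, then restore the (n, n) layout.
Q3 = torch.index_select(P, 0, A.reshape(-1)).reshape(n, n, -1)

assert torch.equal(Q1, Q2) and torch.equal(Q1, Q3)
```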