PyTorch tensor multi-dimensional selection

Time:02-05

I have a question about doing a multi-dimensional selection on a PyTorch tensor efficiently.

Assume I have a tensor a with

import torch
B, V, d = 2, 20000, 64
a = torch.rand(B, V, d)  # [B, V, d]

and an index tensor b with

N, k = 30000, 10
# each entry of b is a row index into a's V dimension, in [0, V)
b = torch.randint(0, V, (B, N, k))  # [B, N, k]

The goal is to build a tensor c of shape [B, N, k, d] by selecting rows of a, i.e. c[i, n, j, :] = a[i, b[i, n, j], :]. Currently I do it like this:

help_1 = a[:, None, :, :].repeat(1, N, 1, 1)     # [B, N, V, d]
help_2 = b[:, :, :, None].expand(-1, -1, -1, d)  # [B, N, k, d]
c = torch.gather(help_1, dim=2, index=help_2)   # [B, N, k, d]
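To make the intended semantics concrete, here is a small sketch that runs the gather construction above on tiny, made-up sizes (the sizes are assumptions chosen only so the loops finish instantly) and checks it against the element-wise definition c[i, n, j, :] = a[i, b[i, n, j], :].

```python
import torch

# Small, hypothetical sizes so the explicit loops below finish quickly.
B, V, d, N, k = 2, 7, 3, 5, 4
a = torch.rand(B, V, d)
b = torch.randint(0, V, (B, N, k))

# The gather-based construction from the question.
help_1 = a[:, None, :, :].repeat(1, N, 1, 1)     # [B, N, V, d]
help_2 = b[:, :, :, None].expand(-1, -1, -1, d)  # [B, N, k, d]
c = torch.gather(help_1, dim=2, index=help_2)   # [B, N, k, d]

# Reference: c[i, n, j, :] should equal a[i, b[i, n, j], :].
ref = torch.empty(B, N, k, d)
for i in range(B):
    for n in range(N):
        for j in range(k):
            ref[i, n, j] = a[i, b[i, n, j]]

print(c.shape, torch.equal(c, ref))
```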

This produces the desired result, but it is very inefficient because it materializes the huge help_1 tensor of size [2, 30000, 20000, 64]: about 76.8 billion elements, or roughly 307 GB in float32, while the result c is only [2, 30000, 10, 64]. Does anyone know how to do this without creating such a large helper tensor? Thank you!
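Part of the cost comes specifically from .repeat, which copies a N times. A possible middle ground, sketched below on tiny hypothetical sizes, is to use .expand instead. It returns a view with stride 0 along the new dimension, so no copy of a is made, and torch.gather accepts such a view as input. One caveat, as far as I understand PyTorch's autograd: if a requires gradients, the backward of gather allocates a zero tensor of the input's full [B, N, V, d] shape, so this only avoids the blow-up in the forward pass.

```python
import torch

# Small, hypothetical sizes; the point is the memory behaviour, not speed.
B, V, d, N, k = 2, 7, 3, 5, 4
a = torch.rand(B, V, d)
b = torch.randint(0, V, (B, N, k))
help_2 = b[:, :, :, None].expand(-1, -1, -1, d)       # [B, N, k, d], a view of b

# .repeat allocates new storage holding N copies of a.
help_1_copy = a[:, None, :, :].repeat(1, N, 1, 1)     # [B, N, V, d]
# .expand returns a view: stride 0 along dim 1, same storage as a.
help_1_view = a[:, None, :, :].expand(-1, N, -1, -1)  # [B, N, V, d]

print(help_1_view.stride())                      # second entry is 0
print(help_1_view.data_ptr() == a.data_ptr())    # shares a's memory

c_copy = torch.gather(help_1_copy, dim=2, index=help_2)
c_view = torch.gather(help_1_view, dim=2, index=help_2)
print(torch.equal(c_copy, c_view))
```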

CodePudding user response:

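You could use broadcasting with advanced indexing to avoid the large helper tensor. The three index tensors below broadcast to [B, N, k, d], so c is built directly without ever materializing anything of size [B, N, V, d]: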

idx0 = torch.arange(B, device=b.device).reshape(-1, 1, 1, 1)  # [B, 1, 1, 1]
idx1 = b[..., None]                                           # [B, N, k, 1]
idx2 = torch.arange(d, device=b.device).reshape(1, 1, 1, -1)  # [1, 1, 1, d]
c = a[idx0, idx1, idx2]  # [B, N, k, d]
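The same selection can likely be written more compactly. The sketch below (tiny, hypothetical sizes) checks two equivalent forms against the three-index version: indexing only the first two dimensions, so the trailing d dimension is kept as an implicit slice, and folding B into V and using a single index_select over the rows of a reshaped [B*V, d] view. Both produce only the [B, N, k, d] result; the names batch, flat_idx, c_two_idx and c_flat are illustrative.

```python
import torch

# Small, hypothetical sizes for a quick equivalence check.
B, V, d, N, k = 2, 7, 3, 5, 4
a = torch.rand(B, V, d)
b = torch.randint(0, V, (B, N, k))

# The answer's version: three broadcast index tensors -> [B, N, k, d].
idx0 = torch.arange(B, device=b.device).reshape(-1, 1, 1, 1)
idx2 = torch.arange(d, device=b.device).reshape(1, 1, 1, -1)
c_answer = a[idx0, b[..., None], idx2]

# Variant 1: index only dims 0 and 1; dim d is carried along as a slice.
batch = torch.arange(B, device=b.device)[:, None, None]   # [B, 1, 1]
c_two_idx = a[batch, b]                                    # [B, N, k, d]

# Variant 2: row i*V + b[i, n, j] of a.reshape(B*V, d) is a[i, b[i, n, j]].
flat_idx = (b + batch * V).reshape(-1)                     # [B*N*k]
c_flat = a.reshape(B * V, d).index_select(0, flat_idx).reshape(B, N, k, d)

print(torch.equal(c_answer, c_two_idx), torch.equal(c_answer, c_flat))
```

Variant 1 is usually the most readable; variant 2 makes the row lookup explicit, which can be handy if the data is later stored as one flat [B*V, d] table.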