I have a question about efficiently selecting elements from a multidimensional PyTorch tensor.
Assume I have a tensor a:
import torch
B, V, d = 2, 20000, 64
a = torch.rand(B, V, d)  # [B, V, d]
and a tensor b:
N, k = 30000, 10
# each entry of b is an index into dim 1 of a, in the range [0, V)
b = torch.randint(0, V, (B, N, k))  # [B, N, k]
The goal is to build a tensor c of shape [B, N, k, d] with c[i, n, j, :] = a[i, b[i, n, j], :]. My current approach is:
help_1 = a[:, None, :, :].repeat(1, N, 1, 1) # [B, N, V, d]
help_2 = b[:, :, :, None].expand(-1,-1,-1,d) # [B, N, k, d]
c = torch.gather(help_1, dim=2, index=help_2)
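To pin down what c is supposed to contain, here is a minimal sketch at toy sizes (the sizes are assumptions chosen so that help_1 stays tiny). It runs the approach above and compares it against an explicit triple loop that spells out c[i, n, j, :] = a[i, b[i, n, j], :]:

```python
import torch

# Toy sizes (assumed, far smaller than in the question) so help_1 stays tiny
B, V, d = 2, 5, 3
N, k = 4, 2

a = torch.rand(B, V, d)
b = torch.randint(0, V, (B, N, k))

# The approach above: one full copy of a for every n
help_1 = a[:, None, :, :].repeat(1, N, 1, 1)     # [B, N, V, d]
help_2 = b[:, :, :, None].expand(-1, -1, -1, d)  # [B, N, k, d]
c = torch.gather(help_1, dim=2, index=help_2)    # [B, N, k, d]

# Reference definition with explicit loops: c[i, n, j, :] = a[i, b[i, n, j], :]
ref = torch.empty(B, N, k, d)
for i in range(B):
    for n in range(N):
        for j in range(k):
            ref[i, n, j] = a[i, b[i, n, j]]

print(c.shape)              # torch.Size([2, 4, 2, 3])
print(torch.equal(c, ref))  # True
```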
This produces the correct result, but it is very inefficient: help_1 has shape [2, 30000, 20000, 64], i.e. about 7.7e10 elements, or roughly 307 GB in float32. Is there a way to do this selection without creating such a large helper tensor? Thanks!
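One way to keep torch.gather while dropping the repeated helper is to gather along dim 1 of a directly, after flattening the N and k axes of b into one axis. The following is a minimal sketch under that assumption; the function name gather_rows and the toy sizes are hypothetical and only for illustration. Because expand returns a view, the only large allocation is the output itself, which at the question's sizes is [2, 30000, 10, 64], about 154 MB in float32.

```python
import torch

def gather_rows(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: returns c with c[i, n, j, :] = a[i, b[i, n, j], :].

    a: [B, V, d] tensor; b: [B, N, k] int64 indices into dim 1 of a.
    """
    B, N, k = b.shape
    d = a.shape[-1]
    # Flatten (N, k) into one axis, then broadcast each index across d.
    # expand creates a view, so no copy of the index is made.
    idx = b.reshape(B, N * k, 1).expand(-1, -1, d)  # [B, N*k, d]
    # out[i, m, l] = a[i, idx[i, m, l], l]
    out = torch.gather(a, dim=1, index=idx)         # [B, N*k, d]
    return out.reshape(B, N, k, d)

# Usage at toy sizes (assumed)
B, V, d, N, k = 2, 5, 3, 4, 2
a = torch.rand(B, V, d)
b = torch.randint(0, V, (B, N, k))
c = gather_rows(a, b)
print(c.shape)  # torch.Size([2, 4, 2, 3])
```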
Answer:
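You can use broadcasting with advanced indexing to save memory. The three index tensors below broadcast together to shape [B, N, k, d], so c[i, n, j, l] = a[i, b[i, n, j], l] and no [B, N, V, d] helper is ever created: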
idx0 = torch.arange(B, device=b.device).reshape(-1, 1, 1, 1) # [B, 1, 1, 1]
idx1 = b[..., None] # [B, N, k, 1]
idx2 = torch.arange(d, device=b.device).reshape(1, 1, 1, -1) # [1, 1, 1, d]
c = a[idx0, idx1, idx2] # [B, N, k, d]
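As a sanity check, a minimal sketch at toy sizes (the sizes are assumptions) confirms that the broadcasted indexing matches the question's gather-based result. It also tries a shorter equivalent form: indexing only the first two dims with a [B, 1, 1] batch index and b itself, which broadcast to [B, N, k], leaves the trailing d dim to be taken whole, so idx2 is not needed.

```python
import torch

# Toy sizes (assumed) for a quick equivalence check
B, V, d, N, k = 2, 5, 3, 4, 2
a = torch.rand(B, V, d)
b = torch.randint(0, V, (B, N, k))

# Broadcasted advanced indexing from the answer: result is [B, N, k, d]
idx0 = torch.arange(B, device=b.device).reshape(-1, 1, 1, 1)
idx1 = b[..., None]
idx2 = torch.arange(d, device=b.device).reshape(1, 1, 1, -1)
c = a[idx0, idx1, idx2]

# Shorter form: [B, 1, 1] batch index broadcasts against b [B, N, k];
# the untouched last dim d is appended, giving [B, N, k, d]
c_short = a[torch.arange(B, device=b.device)[:, None, None], b]

# The question's memory-hungry version, affordable only at toy sizes
help_1 = a[:, None, :, :].repeat(1, N, 1, 1)
help_2 = b[:, :, :, None].expand(-1, -1, -1, d)
c_ref = torch.gather(help_1, dim=2, index=help_2)

print(torch.equal(c, c_ref), torch.equal(c_short, c_ref))  # True True
```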