Usage of torch.scatter() for multi-dimensional value

Time:02-05

I have a question about how to use the torch.scatter() function.

I want to construct a weights tensor, weights, of shape [B, N, V], where B is the batch size, N is the number of points, and V is the number of features per point.

Let's say I have two tensors:

a, of shape [B, N, k]: for each point, k feature indices in the range [0, V-1] that select which features to set.
b, of shape [B, N, k]: for each point, the weights to store at the features selected by a.
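To make the shapes concrete, here is a minimal sketch that builds tensors of this form. The toy sizes and the random way the indices are drawn are illustrative assumptions, not part of the question. Note that PyTorch requires scatter indices to be int64 (torch.long), and the values written must have the same dtype as the destination tensor.

```python
import torch

# Toy sizes (illustrative assumptions, not from the question)
B, N, V, k = 2, 3, 5, 2

# weights: [B, N, V], initialized to zero before anything is scattered into it
weights = torch.zeros(B, N, V)

# a: [B, N, k] feature indices in [0, V-1]. argsort of random values gives a
# random permutation of range(V) per point; keeping the first k yields k
# distinct int64 indices per point.
a = torch.rand(B, N, V).argsort(dim=-1)[..., :k]

# b: [B, N, k] weight for each selected feature, same dtype (float32) as weights
b = torch.rand(B, N, k)
```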

I tried using torch.scatter(): weights.scatter_(index=a, dim=2, value=some_fix_value). With this call I can only set every selected position to one fixed value; I can't pass the whole tensor b, which holds a separate weight for each selected position.
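For reference, a minimal sketch of that scalar-value behaviour on a tiny, made-up example (the sizes and indices are assumptions): every position selected by a receives the same scalar, so per-position weights cannot be expressed this way.

```python
import torch

B, N, V = 1, 2, 4  # toy sizes (assumed)
weights = torch.zeros(B, N, V)

# [B, N, k] with k = 2; integer literals give an int64 tensor, as scatter_ requires
a = torch.tensor([[[0, 2], [1, 3]]])

# Scalar form: the same value 0.5 is written at every indexed position along dim 2
weights.scatter_(2, a, 0.5)
print(weights)
```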

Can someone give me a hint on how to do this properly?

CodePudding user response:

I believe what you are looking to do is:

weights.scatter_(dim=2, index=a, src=b)
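A minimal runnable sketch of this call on a tiny, made-up example (the sizes, indices and weights are assumptions chosen for illustration). It also shows the out-of-place Tensor.scatter, which returns a new tensor instead of modifying the one it is called on.

```python
import torch

B, N, V = 1, 2, 4  # toy sizes (assumed)
weights = torch.zeros(B, N, V)
a = torch.tensor([[[0, 2], [1, 3]]])          # [B, N, k] feature indices, int64
b = torch.tensor([[[0.1, 0.7], [0.4, 0.9]]])  # [B, N, k] weights, float32 like `weights`

# In place: each b[i][j][k] is written to weights[i][j][a[i][j][k]]
weights.scatter_(dim=2, index=a, src=b)

# Out of place: same result, returned as a new tensor
new_weights = torch.zeros(B, N, V).scatter(2, a, b)
```

If a can contain the same index twice for one point, scatter_ does not specify which of the competing values ends up stored; torch.Tensor.scatter_add_ sums them instead.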

In other words, each entry of a gives the position along weights' last dimension where the matching entry of b (the one at the same [i][j][k]) is written. With dim=2, this corresponds to the following pseudo-code:

weights[i][j][a[i][j][k]] = b[i][j][k]   # for every i, j, k
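One way to check that this pseudo-code describes what scatter_ does is to compare it against explicit Python loops. This is a sketch with random toy tensors (all sizes are assumptions); the loop index is named m rather than k so it doesn't clash with the size k, and the indices are kept distinct per point so no position is written twice.

```python
import torch

torch.manual_seed(0)
B, N, V, k = 2, 3, 5, 2  # toy sizes (assumed)

# k distinct feature indices per point (int64), plus one weight per index
a = torch.rand(B, N, V).argsort(dim=-1)[..., :k]
b = torch.rand(B, N, k)

# Vectorized scatter along dim=2
fast = torch.zeros(B, N, V).scatter_(2, a, b)

# Explicit loops spelling out weights[i][j][a[i][j][m]] = b[i][j][m]
slow = torch.zeros(B, N, V)
for i in range(B):
    for j in range(N):
        for m in range(k):
            slow[i, j, a[i, j, m]] = b[i, j, m]

print(torch.equal(fast, slow))  # True
```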