I have a question about using the torch.scatter() function.
I want to construct a weight tensor, weights, of shape [B, N, V], where B is the batch size, N is the number of points, and V is the number of features per point.
Let's say I have two tensors:
a: shape [B, N, k], where B is the batch size and N is the number of points; for each point, its k entries are feature indices in [0, V - 1] that select which features to set.
b: shape [B, N, k], with the same layout as a; b stores the weight for each selected feature.
I tried torch.scatter(): weights.scatter_(index=a, dim=2, value=some_fix_value). This only lets me write one fixed value at every indexed position, not the individual values stored in b for those positions.
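A minimal reproduction of that attempt, using small hypothetical sizes and a hand-written index tensor a (not taken from the original post), shows that value= writes the same scalar everywhere a points:

```python
import torch

# Hypothetical small sizes, chosen only for illustration.
B, N, V, k = 2, 3, 5, 2

weights = torch.zeros(B, N, V)
# a: for each (batch, point), k distinct feature positions in [0, V - 1].
# torch.tensor infers int64 here, which scatter_ requires for the index.
a = torch.tensor([[[0, 3], [1, 4], [2, 0]],
                  [[4, 1], [3, 2], [0, 1]]])  # shape [B, N, k]

# value= writes one scalar (here 0.5) at every position selected by a.
weights.scatter_(dim=2, index=a, value=0.5)
print(weights[0, 0])  # positions 0 and 3 hold 0.5, the rest stay 0
```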
Can someone give me a hint on how to do this properly?
Answer:
I believe what you are looking to do is:
weights.scatter_(dim=2, index=a, src=b)
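A runnable sketch of that call follows. The sizes and the contents of a and b are hypothetical and serve only as an illustration. weights starts at zero, so any feature that a does not select keeps the weight 0.

```python
import torch

# Hypothetical sizes for illustration only.
B, N, V, k = 2, 3, 5, 2

# a: feature indices in [0, V - 1]; scatter_ requires an int64 (torch.long) index.
a = torch.tensor([[[0, 3], [1, 4], [2, 0]],
                  [[4, 1], [3, 2], [0, 1]]])
# b: the weight to write at each selected feature; same shape as a.
b = torch.arange(B * N * k, dtype=torch.float32).reshape(B, N, k) + 1.0

# weights must have the same dtype as b; unselected positions stay 0.
weights = torch.zeros(B, N, V, dtype=b.dtype)
weights.scatter_(dim=2, index=a, src=b)

# b[0, 0] is [1., 2.] and a[0, 0] is [0, 3], so feature 0 gets 1 and feature 3 gets 2.
print(weights[0, 0])
```

If a can repeat an index within the same point, scatter_ keeps only one of the colliding values, and which one it keeps is not guaranteed. If the weights should be summed instead, scatter_add_ accumulates them.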
In other words, a and b line up element by element, and each value in a gives the position along the last dimension of weights where the matching value of b is written. With torch.scatter's dim argument set to 2, this is the following operation in pseudocode:
out[i][j][a[i][j][k]] = b[i][j][k]
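To check that the pseudocode and scatter_ agree, the sketch below compares them on random hypothetical inputs. Each point's indices are drawn with torch.randperm, so no position is written twice and the result is deterministic. The loop variable is named m so that it does not shadow the size k.

```python
import torch

# Hypothetical random inputs, used only to compare the two forms.
B, N, V, k = 2, 4, 6, 3
torch.manual_seed(0)
# k distinct feature indices per point, stacked to shape [B, N, k] (int64).
a = torch.stack([torch.stack([torch.randperm(V)[:k] for _ in range(N)])
                 for _ in range(B)])
b = torch.rand(B, N, k)

# Vectorised version: scatter_ writes in place and returns the tensor.
fast = torch.zeros(B, N, V).scatter_(2, a, b)

# Literal translation of out[i][j][a[i][j][k]] = b[i][j][k].
slow = torch.zeros(B, N, V)
for i in range(B):
    for j in range(N):
        for m in range(k):
            slow[i, j, a[i, j, m]] = b[i, j, m]

print(torch.equal(fast, slow))  # True
```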