I have a tensor called rank which is of shape 2x3 as follows:
tensor([[ 0, 1, 2],
        [ 2, 0, 1]])
I want to construct a 2x3x3 tensor that starts out as all zeros (created with torch.zeros(2,3,3)). Then, in each row, I want to set to 1 the entry of the last dimension whose index is given by the corresponding value in rank.
Expected output:
tensor([[[1, 0, 0],
         [0, 1, 0],
         [0, 0, 1]],

        [[0, 0, 1],
         [1, 0, 0],
         [0, 1, 0]]])
The 1s are placed at the indices given by the rank tensor. How can I do this in PyTorch?
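Stated element-wise, the request is out[b, r, rank[b, r]] = 1 for every b and r. A minimal, non-vectorized loop sketch that spells this out (useful only as a reference to check a vectorized version against; the names rank and out are illustrative):

```python
import torch

# rank[b, r] is the column that should become 1 in row r of batch b
rank = torch.tensor([[0, 1, 2],
                     [2, 0, 1]])

# integer dtype so the result matches the expected output above
out = torch.zeros(2, 3, 3, dtype=torch.long)
for b in range(rank.size(0)):
    for r in range(rank.size(1)):
        out[b, r, int(rank[b, r])] = 1

print(out)
```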
Answer:
Given:
import torch

i = torch.tensor([[0, 1, 2],
                  [2, 0, 1]])
x = torch.tensor([[[1, 0, 0],
                   [0, 1, 0],
                   [0, 0, 1]],
                  [[0, 0, 1],
                   [1, 0, 0],
                   [0, 1, 0]]])
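You can perform this operation with Tensor.scatter_, scattering along the last dimension (dim 2). Here x is used only to provide the target shape for expand_as:
>>> torch.zeros(2, 3, 3).scatter_(2, i[:, :, None].expand_as(x), value=1)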
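A variant that does not need the reference tensor x: unsqueeze(-1) turns the (2, 3) index into shape (2, 3, 1), which scatter_ accepts directly, so each row writes a single 1. Passing dtype=torch.long gives integers instead of the default floats. For comparison, torch.nn.functional.one_hot produces the same tensor; this is a swapped-in alternative, not part of the original answer.

```python
import torch
import torch.nn.functional as F

i = torch.tensor([[0, 1, 2],
                  [2, 0, 1]])

# scatter the value 1 along dim 2; the index has shape (2, 3, 1)
scattered = torch.zeros(2, 3, 3, dtype=torch.long).scatter_(2, i.unsqueeze(-1), 1)

# one_hot appends a new last dimension of size num_classes (int64 result)
one_hot = F.one_hot(i, num_classes=3)

print(scattered)
print(torch.equal(scattered, one_hot))  # True
```

scatter_ fits when you already have a preallocated tensor to fill; one_hot is shorter when the last dimension is simply the number of classes.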