I'm trying to sort C (see image) using R to get Sorted_C.
import torch

c = torch.tensor([[[0, 1, 0, 0, 0], [1, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]], [[0, 0, 1, 1, 0], [1, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0]]])
r = torch.tensor([[[0, 0, 0, 7.2, 0], [0, 25.4, 0, 0, 0], [0, 0, 43.6, 0, 0], [61.8, 0, 0, 0, 0], [0, 0, 0, 0, 80]], [[0, 0, 0, 0, 98.2], [116.4, 0, 0, 0, 0], [0, 134.6, 0, 0, 0], [0, 0, 152.8, 0, 0], [0, 0, 0, 169.2, 0]]])
# this is what I need
sorted_c = torch.tensor([[[0, 1, 0, 0, 0], [0, 0, 1, 1, 0], [1, 0, 0, 0, 0], [0, 0, 0, 0, 1]], [[0, 0, 0, 1, 1], [0, 1, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 1, 0, 1]]])
How do I do this efficiently?
Correction: the expected ordering for P6 to P10 should be P10 -> P6 -> P7 -> P8 -> P9.
CodePudding user response:
Minor note: it seems your example r doesn't quite match the values in your drawing. Also, it seems the result for p6 through p10 should be p10 < p6 < p7 < p8 < p9.
When you hear "advanced indexing", you should think torch.gather. That is, the resulting tensor comes from indexing c column-wise with an index tensor that we will extract from r.
First, we can sum r over dim=1 and argsort the result to get the indices of the columns:
>>> idx = r.sum(1).argsort(1)
>>> idx
tensor([[3, 1, 2, 0, 4],
        [4, 0, 1, 2, 3]])
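To make the two sub-steps visible, here is a minimal sketch that splits the sum and the argsort apart. It reuses r from the question; the name col_keys is only introduced here for illustration. It relies on each column of r holding a single nonzero value, so the column sum is simply that value.

```python
import torch

# r copied from the question: shape (2, 5, 5), one nonzero value per column.
r = torch.tensor([[[0, 0, 0, 7.2, 0], [0, 25.4, 0, 0, 0], [0, 0, 43.6, 0, 0], [61.8, 0, 0, 0, 0], [0, 0, 0, 0, 80]], [[0, 0, 0, 0, 98.2], [116.4, 0, 0, 0, 0], [0, 134.6, 0, 0, 0], [0, 0, 152.8, 0, 0], [0, 0, 0, 169.2, 0]]])

# Summing over dim=1 collapses the rows, leaving one sort key per column.
col_keys = r.sum(1)        # shape (2, 5)

# argsort along dim=1 returns, per batch, the column order from smallest key to largest.
idx = col_keys.argsort(1)  # shape (2, 5)

print(col_keys.shape)      # torch.Size([2, 5])
print(idx)
```

The printed idx should match the tensor shown above: position k of each row holds the index of the column that ends up in position k after sorting.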
Then we can apply torch.Tensor.gather, indexing c column-wise using the column indices contained in idx, i.e. dim=2 is the one varying based on the values in idx. Explicitly, the resulting tensor out is constructed such that:
out[i][j][k] = c[i][j][idx[i][j][k]]
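As a sanity check on that rule, the sketch below spells it out as three nested loops and compares the result with gather. It assumes the index tensor has already been broadcast to the shape of c (that step is covered right after this); the names index and out_loop are only used here for illustration, and the loop version is meant for understanding, not for speed.

```python
import torch

# c copied from the question: shape (2, 4, 5).
c = torch.tensor([[[0, 1, 0, 0, 0], [1, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]], [[0, 0, 1, 1, 0], [1, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0]]])

# Column order per batch, as computed from r in the previous step.
idx = torch.tensor([[3, 1, 2, 0, 4], [4, 0, 1, 2, 3]])

# Same index row repeated for every row of c: shape (2, 4, 5).
index = idx[:, None].expand_as(c)

# Reference implementation of out[i][j][k] = c[i][j][index[i][j][k]].
out_loop = torch.empty_like(c)
for i in range(c.shape[0]):
    for j in range(c.shape[1]):
        for k in range(c.shape[2]):
            out_loop[i, j, k] = c[i, j, int(index[i, j, k])]

# The vectorized gather along dim=2 should produce the same tensor.
out_gather = c.gather(2, index)
print(torch.equal(out_loop, out_gather))
```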
Keep in mind that the index tensor (idx) and the value tensor (c) must have the same number of dimensions, and idx cannot be larger than c along any dimension other than the one we're indexing on, here dim=2. Since the output takes the shape of the index tensor, we need to expand idx such that it has the shape of c. We can do so with None-indexing and using expand or expand_as:
>>> idx[:,None].expand_as(c)
tensor([[[3, 1, 2, 0, 4],
         [3, 1, 2, 0, 4],
         [3, 1, 2, 0, 4],
         [3, 1, 2, 0, 4]],
        [[4, 0, 1, 2, 3],
         [4, 0, 1, 2, 3],
         [4, 0, 1, 2, 3],
         [4, 0, 1, 2, 3]]])
Notice the duplicated values row-wise (FYI: they're not copies; expand makes a view, not a copy!).
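One way to convince yourself that expand produces a view is to look at the strides and the underlying data pointer. This is a small sketch using the idx values from above; the names expanded and materialized are only used here for illustration.

```python
import torch

idx = torch.tensor([[3, 1, 2, 0, 4], [4, 0, 1, 2, 3]])

# Insert a size-1 dim and broadcast it to 4 rows: shape (2, 4, 5).
expanded = idx[:, None].expand(2, 4, 5)

# The expanded dim has stride 0, so every "row" reads the same memory.
print(expanded.stride()[1])                      # 0

# The view starts at the same memory address as idx: nothing was copied.
print(expanded.data_ptr() == idx.data_ptr())     # True

# Calling .contiguous() is what actually materializes the repeated rows.
materialized = expanded.contiguous()
print(materialized.data_ptr() == idx.data_ptr()) # False
```

Because the rows share memory, writing in place to an expanded tensor is best avoided; reading from it, as gather does here, is fine.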
Finally, we can gather the values in c to get the desired result:
>>> c.gather(2, idx[:,None].expand_as(c))
tensor([[[0, 1, 0, 0, 0],
         [0, 0, 1, 1, 0],
         [1, 0, 0, 0, 0],
         [0, 0, 0, 0, 1]],
        [[0, 0, 0, 1, 1],
         [0, 1, 1, 0, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 1, 0, 1]]])
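Putting the steps together, here is a self-contained sketch that checks the result against the sorted_c given in the question. The helper name sort_columns_by is hypothetical and only wraps the two lines from this answer. The last part shows an equivalent formulation with broadcasted integer indexing, included as a cross-check and not as a recommendation over gather.

```python
import torch


def sort_columns_by(c, r):
    """Reorder the last dim of c by ascending column sums of r (hypothetical helper)."""
    idx = r.sum(1).argsort(1)                      # (B, K) column order per batch
    return c.gather(2, idx[:, None].expand_as(c))  # (B, N, K)


# Tensors copied from the question.
c = torch.tensor([[[0, 1, 0, 0, 0], [1, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]], [[0, 0, 1, 1, 0], [1, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0]]])
r = torch.tensor([[[0, 0, 0, 7.2, 0], [0, 25.4, 0, 0, 0], [0, 0, 43.6, 0, 0], [61.8, 0, 0, 0, 0], [0, 0, 0, 0, 80]], [[0, 0, 0, 0, 98.2], [116.4, 0, 0, 0, 0], [0, 134.6, 0, 0, 0], [0, 0, 152.8, 0, 0], [0, 0, 0, 169.2, 0]]])
sorted_c = torch.tensor([[[0, 1, 0, 0, 0], [0, 0, 1, 1, 0], [1, 0, 0, 0, 0], [0, 0, 0, 0, 1]], [[0, 0, 0, 1, 1], [0, 1, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 1, 0, 1]]])

out = sort_columns_by(c, r)
print(torch.equal(out, sorted_c))  # True

# Cross-check with broadcasted integer indexing: alt[b, n, k] = c[b, n, idx[b, k]].
B, N, K = c.shape
idx = r.sum(1).argsort(1)
alt = c[torch.arange(B)[:, None, None], torch.arange(N)[None, :, None], idx[:, None, :]]
print(torch.equal(alt, out))       # True
```

Both versions avoid Python loops; the gather form keeps the index as a stride-0 view, so the only sizable allocation is the output itself.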