Sort a multi-dimensional tensor using another tensor

Time:08-04

I'm trying to sort C (see image) using R to get Sorted_C.


import torch

c = torch.tensor([[[0, 1, 0, 0, 0], [1, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]], [[0, 0, 1, 1, 0], [1, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0]]])
r = torch.tensor([[[0, 0, 0, 7.2, 0], [0, 25.4, 0, 0, 0], [0, 0, 43.6, 0, 0], [61.8, 0, 0, 0, 0], [0, 0, 0, 0, 80]], [[0, 0, 0, 0, 98.2], [116.4, 0, 0, 0, 0], [0, 134.6, 0, 0, 0], [0, 0, 152.8, 0, 0], [0, 0, 0, 169.2, 0]]])

# this is what I need
sorted_c = torch.tensor([[[0, 1, 0, 0, 0], [0, 0, 1, 1, 0], [1, 0, 0, 0, 0], [0, 0, 0, 0, 1]], [[0, 0, 0, 1, 1], [0, 1, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 1, 0, 1]]])

How do I do this efficiently?

Correction: the expected ordering for P6 to P10 should be P10 -> P6 -> P7 -> P8 -> P9.

Answer:

Minor note: it seems your example r doesn't quite match the values in your drawing. Also, it seems the result for p6 through p10 should be p10 < p6 < p7 < p8 < p9.


When you need to pick values along one dimension using an index tensor (a form of "advanced indexing"), you should think torch.gather. Here, the resulting tensor comes from indexing c column-wise with an index tensor that we extract from r.

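First, sum r over its rows (dim=1) and argsort the result along its last dimension to get the column order for each batch element. Since each column of r contains a single non-zero value, the column sum is exactly that value: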

>>> idx = r.sum(1).argsort(1)
>>> idx
tensor([[3, 1, 2, 0, 4],
        [4, 0, 1, 2, 3]])
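Summing only works as a sort key because each column of r has a single non-zero entry, as in the example above. A minimal sketch that checks this assumption explicitly before sorting (the count_nonzero check is an addition for illustration, not part of the approach above):

```python
import torch

r = torch.tensor([[[0, 0, 0, 7.2, 0], [0, 25.4, 0, 0, 0], [0, 0, 43.6, 0, 0], [61.8, 0, 0, 0, 0], [0, 0, 0, 0, 80]],
                  [[0, 0, 0, 0, 98.2], [116.4, 0, 0, 0, 0], [0, 134.6, 0, 0, 0], [0, 0, 152.8, 0, 0], [0, 0, 0, 169.2, 0]]])

# Assumption: every column of r holds at most one non-zero value,
# so summing over the rows (dim=1) simply picks out that value.
assert (torch.count_nonzero(r, dim=1) <= 1).all()

keys = r.sum(dim=1)        # shape (2, 5): one sort key per column
idx = keys.argsort(dim=1)  # ascending column order for each batch element
print(keys)
print(idx)
```

If two columns ever share the same key, keep in mind that argsort's default sort is not guaranteed to be stable, so tied columns may come out in either order.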

Then we can use torch.Tensor.gather to index c column-wise with the column indices contained in idx, i.e. dim=2 is the dimension whose positions are chosen by the values in idx. Explicitly, since the same column order is used for every row j, the resulting tensor out is constructed such that:

out[i][j][k] = c[i][j][idx[i][k]]
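To make that rule concrete, here is a deliberately slow reference that builds out with plain Python loops (gather_columns_loop is a hypothetical helper name, not a PyTorch API) and checks it against the gather call used in this answer:

```python
import torch

c = torch.tensor([[[0, 1, 0, 0, 0], [1, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]],
                  [[0, 0, 1, 1, 0], [1, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0]]])
idx = torch.tensor([[3, 1, 2, 0, 4],
                    [4, 0, 1, 2, 3]])

def gather_columns_loop(c, idx):
    """Hypothetical helper: out[i][j][k] = c[i][j][idx[i][k]], one element at a time."""
    out = torch.empty_like(c)
    batch, rows, cols = c.shape
    for i in range(batch):
        for j in range(rows):
            for k in range(cols):
                # The same column order idx[i] is reused for every row j.
                out[i, j, k] = c[i, j, int(idx[i, k])]
    return out

slow = gather_columns_loop(c, idx)
fast = c.gather(2, idx[:, None].expand_as(c))
print(torch.equal(slow, fast))  # True
```

The loop is only useful for checking the semantics; gather performs the same lookup in a single vectorized call.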

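Keep in mind that gather requires the index tensor to have the same number of dimensions as the value tensor (c), and that the output takes the shape of the index tensor; along every dimension other than the one we're indexing on (here dim=2), the index must not be larger than c. Since idx has shape (2, 5) and c has shape (2, 4, 5), we need to expand idx to the shape of c. We can do so with None-indexing followed by expand or expand_as: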

>>> idx[:,None].expand_as(c)
tensor([[[3, 1, 2, 0, 4],
         [3, 1, 2, 0, 4],
         [3, 1, 2, 0, 4],
         [3, 1, 2, 0, 4]],

        [[4, 0, 1, 2, 3],
         [4, 0, 1, 2, 3],
         [4, 0, 1, 2, 3],
         [4, 0, 1, 2, 3]]])

Notice the values are repeated row-wise (FYI: they're not copies; expand makes a view, not a copy!)
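A short sketch to check both points: the expanded index shares memory with idx (its row stride is 0), and without expanding, gather only returns as many rows as the index has. repeat is shown only for contrast, as the copying alternative:

```python
import torch

c = torch.tensor([[[0, 1, 0, 0, 0], [1, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]],
                  [[0, 0, 1, 1, 0], [1, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0]]])
idx = torch.tensor([[3, 1, 2, 0, 4],
                    [4, 0, 1, 2, 3]])

# expand() returns a view: the inserted row dimension gets stride 0,
# so all 4 rows read the same 5 stored values of idx.
expanded = idx[:, None].expand_as(c)
print(expanded.shape)                          # torch.Size([2, 4, 5])
print(expanded.stride()[1])                    # 0
print(expanded.data_ptr() == idx.data_ptr())   # True: same underlying memory

# repeat(), by contrast, allocates new memory and copies the values.
repeated = idx[:, None].repeat(1, c.shape[1], 1)
print(torch.equal(repeated, expanded))         # True: same values
print(repeated.data_ptr() == idx.data_ptr())   # False: a separate copy

# Without expanding, gather's output takes the index's shape (2, 1, 5),
# so only the first row of each c[i] is reordered.
partial = c.gather(2, idx[:, None])
print(partial.shape)                           # torch.Size([2, 1, 5])
```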

Finally, we can gather the values in c to get the desired result:

>>> c.gather(2, idx[:,None].expand_as(c))
tensor([[[0, 1, 0, 0, 0],
         [0, 0, 1, 1, 0],
         [1, 0, 0, 0, 0],
         [0, 0, 0, 0, 1]],

        [[0, 0, 0, 1, 1],
         [0, 1, 1, 0, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 1, 0, 1]]])
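Putting the two steps together, a minimal sketch of a reusable function (sort_columns_by is a hypothetical name), assuming c has shape (B, N, M) and r has shape (B, K, M) with one non-zero entry per column, as in the question:

```python
import torch

def sort_columns_by(c: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: reorder the columns of c (B x N x M) in ascending
    order of the per-column keys r.sum(1), where r has shape B x K x M.
    Assumes each column of r has a single non-zero entry."""
    idx = r.sum(dim=1).argsort(dim=1)              # (B, M) column order
    return c.gather(2, idx[:, None].expand_as(c))  # (B, N, M), columns reordered

c = torch.tensor([[[0, 1, 0, 0, 0], [1, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]],
                  [[0, 0, 1, 1, 0], [1, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0]]])
r = torch.tensor([[[0, 0, 0, 7.2, 0], [0, 25.4, 0, 0, 0], [0, 0, 43.6, 0, 0], [61.8, 0, 0, 0, 0], [0, 0, 0, 0, 80]],
                  [[0, 0, 0, 0, 98.2], [116.4, 0, 0, 0, 0], [0, 134.6, 0, 0, 0], [0, 0, 152.8, 0, 0], [0, 0, 0, 169.2, 0]]])
sorted_c = torch.tensor([[[0, 1, 0, 0, 0], [0, 0, 1, 1, 0], [1, 0, 0, 0, 0], [0, 0, 0, 0, 1]],
                         [[0, 0, 0, 1, 1], [0, 1, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 1, 0, 1]]])

result = sort_columns_by(c, r)
print(torch.equal(result, sorted_c))  # True
```

Both steps are batched tensor operations with no Python-level loops, which should keep this approach efficient as B, N and M grow.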