pytorch tensor sort rows based on column

Time:08-17

Given a 2D tensor like this:

tensor([[0.8771, 0.0976, 0.8186],
        [0.7044, 0.4783, 0.0350],
        [0.4239, 0.8341, 0.3693],
        [0.5568, 0.9175, 0.0763],
        [0.0876, 0.1651, 0.2776]])

How do you sort the rows based on the values in a column? For instance, if we sort based on the last column, I would expect the rows to be:

tensor([[0.7044, 0.4783, 0.0350],
        [0.5568, 0.9175, 0.0763],
        [0.0876, 0.1651, 0.2776],
        [0.4239, 0.8341, 0.3693],
        [0.8771, 0.0976, 0.8186]])

Values in the last column are now in ascending order.

Answer:

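You can use Python's built-in sorted with a lambda as the key. The sorting key is the last item of each row, x[-1]. This works on a nested Python list (for a tensor t, t.tolist() gives one) and returns a list, not a tensor.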

tensor = [[0.8771, 0.0976, 0.8186],
        [0.7044, 0.4783, 0.0350],
        [0.4239, 0.8341, 0.3693],
        [0.5568, 0.9175, 0.0763],
        [0.0876, 0.1651, 0.2776]]
sorted(tensor,key=lambda x: x[-1])


Result: 

 [[0.7044, 0.4783, 0.035],
 [0.5568, 0.9175, 0.0763],
 [0.0876, 0.1651, 0.2776],
 [0.4239, 0.8341, 0.3693],
 [0.8771, 0.0976, 0.8186]]
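If you want to keep using sorted but stay with a tensor, a minimal variant (not part of the original answer) is to iterate the tensor directly. Iterating over a 2D tensor yields its rows as 1D tensors, .item() turns each key into a Python float, and torch.stack joins the sorted rows back into a 2D tensor.

```python
import torch

t = torch.tensor([[0.8771, 0.0976, 0.8186],
                  [0.7044, 0.4783, 0.0350],
                  [0.4239, 0.8341, 0.3693],
                  [0.5568, 0.9175, 0.0763],
                  [0.0876, 0.1651, 0.2776]])

# Iterating over a 2D tensor yields its rows; sort them by their last element
rows = sorted(t, key=lambda row: row[-1].item())

# sorted() returns a Python list of 1D tensors, so stack them back into 2D
result = torch.stack(rows)
print(result)
```

Because the comparisons run in Python, this is slower than the index-based approaches for large tensors.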

Answer:

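import torch

t = torch.rand(5, 3)
col_index_to_sort = 2

# sort() returns (values, indices); [1] keeps the indices
sorted_indices = t[:, col_index_to_sort].sort()[1]
# reorder the rows of t using those indices
t = t[sorted_indices]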
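The same pattern works for any column and either direction. A minimal sketch, assuming you want the rows in descending order of the first column (the column choice and the fixed seed are only for illustration): Tensor.sort accepts descending=True, and its result exposes the row permutation as .indices.

```python
import torch

torch.manual_seed(0)  # fixed seed only so the example is reproducible
t = torch.rand(5, 3)
col_index_to_sort = 0  # hypothetical choice: sort by the first column

# sort() returns a (values, indices) named tuple; .indices is the row permutation
sorted_indices = t[:, col_index_to_sort].sort(descending=True).indices

# Indexing with a 1D tensor of indices reorders whole rows
t_sorted = t[sorted_indices]
print(t_sorted)
```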

Answer:

import torch

a = torch.rand(5, 3)  # replace with your tensor
ind = a[:, -1].argsort(dim=0)  # row order that sorts the last column
sorted_a = a[ind]

torch.argsort "Returns the indices that sort a tensor along a given dimension in ascending order by value." So you get the sorting indices for the last column and then reorder the rows with those indices.
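Applied to the tensor from the question, a quick sketch shows what the indices look like and that a[ind] selects whole rows along dim 0, which is the same as calling torch.index_select(a, 0, ind) explicitly.

```python
import torch

a = torch.tensor([[0.8771, 0.0976, 0.8186],
                  [0.7044, 0.4783, 0.0350],
                  [0.4239, 0.8341, 0.3693],
                  [0.5568, 0.9175, 0.0763],
                  [0.0876, 0.1651, 0.2776]])

# Indices that would sort the last column in ascending order
ind = a[:, -1].argsort(dim=0)
print(ind)  # tensor([1, 3, 4, 2, 0])

# Reorder the rows; index_select along dim 0 does the same thing explicitly
sorted_a = a[ind]
same = torch.index_select(a, 0, ind)
print(sorted_a)
```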
