I know that we can sort the columns of a 2D NumPy array based on one of its rows as follows:
import numpy as np

a = np.array([[1, 4, 7],
              [3, 1, 5],
              [9, 5, 8]])
a = a[:, a[1, :].argsort()]
Out:
[[4, 1, 7],
 [1, 3, 5],
 [5, 9, 8]]
Please note that this is indeed what I want. The second row (index 1) is now sorted, and the values in rows 0 and 2 have shifted accordingly. That is, the column positions change according to the sorting order of row 1.
But now to my problem: I don't have a 2D array but a 3D array (i.e. an array of 2D arrays).
a = np.array([[[1, 4, 7],
               [3, 1, 5],
               [9, 5, 8]],
              [[2, 8, 7],
               [3, 8, 1],
               [9, 2, 8]]])
I still want to sort the columns of each 2D array individually, based on the values in its own row 1. The desired result would be:
[[[4, 1, 7],
  [1, 3, 5],
  [5, 9, 8]],
 [[7, 2, 8],
  [1, 3, 8],
  [8, 9, 2]]]
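For reference, a straightforward (non-vectorized) way to get this result is to apply the 2D recipe to each slice in a Python loop and stack the results. This is only a baseline sketch for checking vectorized solutions against; because it loops in Python, it will likely be slower when there are many slices.

```python
import numpy as np

a = np.array([[[1, 4, 7],
               [3, 1, 5],
               [9, 5, 8]],
              [[2, 8, 7],
               [3, 8, 1],
               [9, 2, 8]]])

# Apply the 2D column sort (keyed on row 1) to each 2D slice separately,
# then stack the sorted slices back into a 3D array along axis 0.
expected = np.stack([m[:, m[1].argsort()] for m in a])
print(expected)
```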
I tried the following, but the result is not what I want:
a = a[:, :, a[:, 1, :].argsort()]
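A sketch of why this attempt fails: `a[:, 1, :].argsort()` has shape (2, 3), one permutation per slice, but using it as a fancy index on the last axis applies every permutation to every slice, so the last axis is replaced by the whole (2, 3) index shape. Pairing each slice index with its own permutation through three broadcast index arrays (a plain advanced-indexing alternative, not taken from the answers below) gives the intended result.

```python
import numpy as np

a = np.array([[[1, 4, 7],
               [3, 1, 5],
               [9, 5, 8]],
              [[2, 8, 7],
               [3, 8, 1],
               [9, 2, 8]]])

# One column permutation per 2D slice, shape (2, 3): [[1, 0, 2], [2, 0, 1]]
order = a[:, 1, :].argsort()

# The attempt: `order` indexes the last axis of EVERY slice, so
# wrong[i, j, p, q] == a[i, j, order[p, q]] and the shape becomes (2, 3, 2, 3).
wrong = a[:, :, order]
print(wrong.shape)  # (2, 3, 2, 3)

# Pair slice i with its own permutation order[i]: the three index arrays
# broadcast to shape (2, 3, 3) and fixed[i, j, k] == a[i, j, order[i, k]].
i = np.arange(a.shape[0])[:, None, None]
j = np.arange(a.shape[1])[None, :, None]
fixed = a[i, j, order[:, None, :]]
print(fixed)
```

Only the "diagonal" entries `wrong[i, :, i, :]` of the failed attempt match the desired slices; the other combinations apply one slice's permutation to a different slice.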
Answer:
Try np.take_along_axis:
np.take_along_axis(a,a[:,1].argsort()[:,None], axis=2)
Out:
array([[[4, 1, 7],
        [1, 3, 5],
        [5, 9, 8]],

       [[7, 2, 8],
        [1, 3, 8],
        [8, 9, 2]]])
Honestly, don't ask me why it works :-)
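Here is one way to see why it works. `a[:, 1].argsort()` gives one column order per slice, shape (2, 3), and `[:, None]` turns it into shape (2, 1, 3). With `axis=2`, `np.take_along_axis` picks `out[i, j, k] = a[i, j, idx[i, 0, k]]`, broadcasting the length-1 row axis of `idx` over every row `j` of slice `i`. The explicit loop below is just an illustrative re-implementation for comparison, not how NumPy computes it internally.

```python
import numpy as np

a = np.array([[[1, 4, 7],
               [3, 1, 5],
               [9, 5, 8]],
              [[2, 8, 7],
               [3, 8, 1],
               [9, 2, 8]]])

idx = a[:, 1].argsort()[:, None]   # shape (2, 1, 3): one column order per slice
out = np.take_along_axis(a, idx, axis=2)

# Explicit-loop equivalent: every row j of slice i is reordered with the
# same column order idx[i, 0], i.e. the order derived from that slice's row 1.
manual = np.empty_like(a)
for i in range(a.shape[0]):
    for j in range(a.shape[1]):
        for k in range(a.shape[2]):
            manual[i, j, k] = a[i, j, idx[i, 0, k]]

print(np.array_equal(out, manual))  # True
```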
Answer:
You can use a combination of numpy.argsort and numpy.take_along_axis:
idx = np.argsort(a, axis=2)
np.take_along_axis(a, idx[:,None,1], axis=2)
It works by getting the sorting order of every row from argsort, keeping only the relevant row (row 1 here), and reshaping that index array to shape (2, 1, 3) so that take_along_axis broadcasts the same column order over all the other rows of each slice.
Output:
array([[[4, 1, 7],
        [1, 3, 5],
        [5, 9, 8]],

       [[7, 2, 8],
        [1, 3, 8],
        [8, 9, 2]]])