For example, I have:
arr = np.array([[10, 30, 20],
                [30, 20, 10]])
indices = np.array([[2, 1, 0],
                    [2, 1, 0]])
I want to reorder each row of arr using the corresponding row of indices, giving:
[[20 30 10]
 [10 20 30]]
Thank you very much!!
Answer:
Use np.take_along_axis, which (with axis=1) takes the element at column indices[i, j] from row i of arr:
import numpy as np
arr = np.array([[10, 30, 20],
                [30, 20, 10]])
indices = np.array([[2, 1, 0],
                    [2, 1, 0]])
# Along axis=1: res[i, j] = arr[i, indices[i, j]]
res = np.take_along_axis(arr, indices, axis=1)
print(res)
Output:
[[20 30 10]
[10 20 30]]
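For comparison, the same result can be sketched with plain advanced indexing by pairing a column vector of row numbers with indices, so broadcasting lines up each row with its own index row. The names rows, res_fancy, res_take and sorted_rows are just illustrative. The last lines show a common source of such index arrays: indices from np.argsort(arr, axis=1) fed to take_along_axis reproduce np.sort(arr, axis=1).

```python
import numpy as np

arr = np.array([[10, 30, 20],
                [30, 20, 10]])
indices = np.array([[2, 1, 0],
                    [2, 1, 0]])

# Row numbers as a (2, 1) column; broadcasts against indices' (2, 3) shape
rows = np.arange(arr.shape[0])[:, np.newaxis]

# Advanced indexing: res_fancy[i, j] = arr[rows[i, 0], indices[i, j]]
res_fancy = arr[rows, indices]

# Same selection via take_along_axis
res_take = np.take_along_axis(arr, indices, axis=1)

print(res_fancy)
print(np.array_equal(res_fancy, res_take))

# Typical use: indices produced by argsort sort each row independently
sorted_rows = np.take_along_axis(arr, np.argsort(arr, axis=1), axis=1)
print(np.array_equal(sorted_rows, np.sort(arr, axis=1)))
```

take_along_axis is usually the more readable choice because it does not require building the rows helper by hand, and it generalizes to any axis.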