I have an n x 3 ndarray and wanted to sort it by the 3rd column. The 3rd column MAY contain some np.inf values. I used the sorted function. After that I applied the njit decorator to my main function but got a different result, so I wrote this code to check what's going on:
# Sorting functions
import numpy as np
from numba import njit

@njit
def sort_me_numba(arr):
    res = sorted(arr, key=lambda x: x[2], reverse=True)
    return res

def sort_me_python(arr):
    res = sorted(arr, key=lambda x: x[2], reverse=True)
    return res

# Build data in the format I have
arr = np.concatenate([np.random.normal(loc=0.1, scale=0.005, size=1_000).reshape((-1, 1)) for i in range(3)], axis=1)
samples = [0, 14, 53, 344, 43, 654, 435, 33]
arr[samples, 2] = np.inf

print('-----------NUMBA with numpy array------------')
sort_me_numba(arr)
print('-----------PYTHON with numpy array------------')
sort_me_python(arr)
print('-----------PYTHON with list------------')
sort_me_python(list(arr))
And the results really do differ. But when samples = [0, 1, 2, 442] there are no differences. Why?
CodePudding user response:
Your function tells Numba/CPython to sort the rows by their third item (index 2), which contains Inf values. Two infinities compare equal, so a sorting algorithm is free to put the rows that have Inf in that item in any relative order. The different results come from different algorithms being used: CPython's sorted is stable (i.e. it preserves the original order of rows with equal keys), whereas the sort that Numba compiles is not guaranteed to be stable.
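To make the tie issue concrete, here is a small sketch using plain NumPy and CPython only (no Numba). The 4-row array and the names stable_ids, other and keys_descending are made up for illustration. It shows that the rows with Inf keys tie, that CPython's sorted keeps them in input order, and that a different permutation of the tied rows is still a valid descending order, which is what an unstable sort is allowed to return.

```python
import numpy as np

# Hypothetical toy data: column 0 is a row id, column 2 is the sort key.
# Rows 0, 2 and 3 tie because their key is inf.
arr = np.array([
    [0.0, 10.0, np.inf],
    [1.0, 11.0, 0.5],
    [2.0, 12.0, np.inf],
    [3.0, 13.0, np.inf],
])

# Infinities compare equal, so the key alone cannot order these rows.
assert np.inf == np.inf

# CPython's sorted() is stable, also with reverse=True:
# tied rows keep their original relative order.
stable = sorted(arr, key=lambda row: row[2], reverse=True)
stable_ids = [int(row[0]) for row in stable]
print(stable_ids)  # [0, 2, 3, 1]

# Another permutation of the tied rows is just as "sorted" by the key,
# so an unstable algorithm may legitimately return it instead.
other = arr[[3, 0, 2, 1]]
keys_descending = bool(np.all(other[:-1, 2] >= other[1:, 2]))
print(keys_descending)  # True
```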
To solve the problem, you need to use a stable algorithm. np.argsort can be used to find the ordering of the values, and its kind parameter can be set to choose a stable algorithm (called "stable"). Additionally, np.argsort should also be faster than sorted on NumPy arrays because it avoids the many lambda calls and does not handle NumPy rows as slow CPython objects. This should do the job:
arr[np.argsort(-arr[:,2], kind='stable')]
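As a quick check of the one-liner above, the following sketch rebuilds data like the question's (with a fixed seed, which is my addition) and compares the stable argsort result with CPython's sorted. The helper name sort_rows_desc_stable is hypothetical, and the check runs in plain NumPy, not under njit.

```python
import numpy as np

def sort_rows_desc_stable(arr):
    """Sort rows by the third column, descending; tied rows keep input order."""
    # Negating the key turns the ascending stable sort into a descending one
    # while still leaving equal keys in their original relative order.
    order = np.argsort(-arr[:, 2], kind="stable")
    return arr[order]

# Data shaped like the question's; the seed is only here for reproducibility.
rng = np.random.default_rng(0)
arr = rng.normal(loc=0.1, scale=0.005, size=(1_000, 3))
samples = [0, 14, 53, 344, 43, 654, 435, 33]
arr[samples, 2] = np.inf

# Reference result from CPython's stable sorted().
expected = np.array(sorted(arr, key=lambda x: x[2], reverse=True))
result = sort_rows_desc_stable(arr)

same = np.array_equal(result, expected)
print(same)  # True

# The inf rows come first, in their original index order.
first_rows = np.argsort(-arr[:, 2], kind="stable")[:len(samples)].tolist()
print(first_rows)  # [0, 14, 33, 43, 53, 344, 435, 654]
```

If the sort has to run inside an njit function, note that, as far as I know, Numba's np.argsort only accepts kind='quicksort' and kind='mergesort', so kind='mergesort' would be the stable choice to try there. Treat that as an assumption and check the Numba documentation for your version.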