I am trying to filter a NumPy array of arrays. I have written a function like the following:
import numba as nb
import numpy as np

@nb.njit
def numpy_filter(npX):
    # build a boolean mask, one entry per row
    n = np.full(npX.shape[0], True)
    for npo_index in range(npX.shape[0]):
        n[npo_index] = npX[npo_index][0] < 2000 and npX[npo_index][1] < 4000 and npX[npo_index][2] < 5000
    return npX[n]
It took 1.75s (Numba njit mode) for an array of length 600K, while the equivalent list comprehension, [x for x in obj1 if x[0] < 2000 and x[1] < 4000 and x[2] < 5000], takes < 0.5s.
Is there a better implementation of this filtering function that would make it run faster?
CodePudding user response:
Generally with Pandas/NumPy arrays, you'll get the best performance if you
- avoid iterating over the array (see the vectorized sketch after this list)
- only create soft copies or views of the base array
- create a minimal number of intermediate Python objects
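The first point alone usually closes most of the gap. As a minimal self-contained sketch (the data array below is just a stand-in for your 600K x 3 input), the whole row-wise test collapses into one vectorized boolean mask:

import numpy as np

# stand-in for the real 600000 x 3 input array
data = np.random.rand(600000, 3) * 10000

# one boolean comparison per column, combined element-wise;
# everything runs in compiled NumPy code, no Python-level loop
mask = (data[:, 0] < 2000) & (data[:, 1] < 4000) & (data[:, 2] < 5000)
filtered = data[mask]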
Pandas is probably your friend here, allowing you to create a view of the backing NumPy array and operate on all the rows via a shared index.
starting data
This creates a random array of the same shape as your source data, with values ranging from 0 to 10000:
>>> import numpy as np
>>> arr = np.random.rand(600000, 3) * 10000
>>> arr
array([[8079.54193993, 925.74430028, 2031.45569251],
[8232.74161149, 2347.42814063, 7571.21287502],
[7435.52165567, 756.74380534, 1023.12181186],
...,
[2176.36643662, 5374.36584708, 637.43482263],
[2645.0737415 , 9059.42475818, 3913.32941652],
[3626.54923011, 1494.57126083, 6121.65034039]])
create a Pandas DataFrame
This creates a view over your source data so you can work with all the rows together using a shared index:
>>> import pandas as pd
>>> df = pd.DataFrame(arr)
>>> df
0 1 2
0 8079.541940 925.744300 2031.455693
1 8232.741611 2347.428141 7571.212875
2 7435.521656 756.743805 1023.121812
3 4423.799649 2256.125276 7591.732828
4 6892.019075 3170.699818 1625.226953
... ... ... ...
599995 642.104686 3164.107206 9508.818253
599996 102.819102 3068.249711 1299.341425
599997 2176.366437 5374.365847 637.434823
599998 2645.073741 9059.424758 3913.329417
599999 3626.549230 1494.571261 6121.650340
[600000 rows x 3 columns]
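Whether the DataFrame is truly a zero-copy view of arr depends on your pandas version and copy-on-write settings, but for a single-dtype frame like this one you can check it yourself:

>>> np.shares_memory(arr, df.values)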
filter
This builds a boolean mask from each column and uses the combined result to filter the DataFrame:
>>> df[(df[0] < 2000) & (df[1] < 4000) & (df[2] < 5000)]
0 1 2
35 1829.777633 1333.083450 1928.982210
38 653.584288 3129.089395 4753.734920
71 1354.736876 279.202816 5.793797
97 1381.531847 551.465381 3767.436640
115 183.112455 1573.272310 1973.143995
... ... ... ...
599963 1895.537096 1695.569792 1866.575164
599970 1061.011239 51.534961 1014.290040
599988 1780.535714 2311.671494 1012.828410
599994 878.643910 352.858091 3014.505666
599996 102.819102 3068.249711 1299.341425
[24067 rows x 3 columns]
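If you need the result back as a plain NumPy array rather than a DataFrame, .to_numpy() hands you the values (the shape matches the 24067 filtered rows shown above):

>>> filtered = df[(df[0] < 2000) & (df[1] < 4000) & (df[2] < 5000)].to_numpy()
>>> filtered.shape
(24067, 3)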
A proper benchmark may follow, but it's very fast.
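In the meantime, a rough timing on your own machine could look like this (a sketch using the standard timeit module; absolute numbers will vary with hardware):

import timeit

# average the vectorized Pandas filter over 100 runs
t = timeit.timeit(
    lambda: df[(df[0] < 2000) & (df[1] < 4000) & (df[2] < 5000)],
    number=100,
)
print(f"{t / 100:.4f} s per run")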
CodePudding user response:
The jit function had not been warmed up: the first call includes compilation time. After that first run, the same task takes only 0.07s.
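To measure it that way yourself, call the function once so Numba compiles it, then time the second call (a minimal sketch, assuming numpy_filter and arr as defined above):

import time

numpy_filter(arr)   # first call triggers JIT compilation

start = time.perf_counter()
numpy_filter(arr)   # subsequent calls run the already-compiled code
print(f"{time.perf_counter() - start:.3f} s")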