fastest way to filter a 2d numpy array

Time:09-23

I am trying to filter a 2D NumPy array, and I have written a function like the following:

import numba as nb
import numpy as np

@nb.njit
def numpy_filter(npX):
    n = np.full(npX.shape[0], True)

    for npo_index in range(npX.shape[0]):
        n[npo_index] = npX[npo_index][0] < 2000 and npX[npo_index][1] < 4000 and npX[npo_index][2] < 5000

    return npX[n]

It took 1.75 s (numba njit mode) for an array of length 600K, while the equivalent list comprehension `[x for x in obj1 if x[0] < 2000 and x[1] < 4000 and x[2] < 5000]` takes under 0.5 s.

Is there a better implementation of this filtering function that would make it run faster?

CodePudding user response:

Generally with Pandas/NumPy arrays, you'll get the best performance if you

  • avoid iterating over the array
  • only create soft copies or views of the base array
  • create a minimal number of intermediate Python objects
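In plain NumPy, those principles reduce to a single vectorized boolean mask with no Python-level loop. A minimal sketch, using random data of the same shape as a stand-in for the real array:

```python
import numpy as np

# Random stand-in for the 600K x 3 source array
arr = np.random.rand(600000, 3) * 10000

# Each column-wise comparison yields a boolean array; & combines them element-wise
mask = (arr[:, 0] < 2000) & (arr[:, 1] < 4000) & (arr[:, 2] < 5000)

# Boolean indexing copies only the matching rows in one pass
filtered = arr[mask]
```

The whole filter runs in compiled NumPy code, so it typically beats both the list comprehension and the un-warmed njit loop.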

Pandas is probably your friend here: it lets you create a view over the backing NumPy array and operate on the rows via a shared index.

starting data

This creates a random array of the same shape as your source data, with values ranging from 0 to 10000:

>>> import numpy as np
>>> arr = np.random.rand(600000, 3) * 10000
>>> arr
array([[8079.54193993,  925.74430028, 2031.45569251],
       [8232.74161149, 2347.42814063, 7571.21287502],
       [7435.52165567,  756.74380534, 1023.12181186],
       ...,
       [2176.36643662, 5374.36584708,  637.43482263],
       [2645.0737415 , 9059.42475818, 3913.32941652],
       [3626.54923011, 1494.57126083, 6121.65034039]])

create a Pandas DataFrame

This creates a view over your source data so you can work with all the rows together using a shared index:

>>> import pandas as pd
>>> df = pd.DataFrame(arr)
>>> df
                  0            1            2
0       8079.541940   925.744300  2031.455693
1       8232.741611  2347.428141  7571.212875
2       7435.521656   756.743805  1023.121812
3       4423.799649  2256.125276  7591.732828
4       6892.019075  3170.699818  1625.226953
...             ...          ...          ...
599995   642.104686  3164.107206  9508.818253
599996   102.819102  3068.249711  1299.341425
599997  2176.366437  5374.365847   637.434823
599998  2645.073741  9059.424758  3913.329417
599999  3626.549230  1494.571261  6121.650340

[600000 rows x 3 columns]

filter

This builds a boolean index from each column's comparison and uses the combined result to filter the DataFrame:

>>> df[(df[0] < 2000) & (df[1] < 4000) & (df[2] < 5000)]
                  0            1            2
35      1829.777633  1333.083450  1928.982210
38       653.584288  3129.089395  4753.734920
71      1354.736876   279.202816     5.793797
97      1381.531847   551.465381  3767.436640
115      183.112455  1573.272310  1973.143995
...             ...          ...          ...
599963  1895.537096  1695.569792  1866.575164
599970  1061.011239    51.534961  1014.290040
599988  1780.535714  2311.671494  1012.828410
599994   878.643910   352.858091  3014.505666
599996   102.819102  3068.249711  1299.341425

[24067 rows x 3 columns]
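If the downstream code expects a NumPy array rather than a DataFrame, the filtered result converts back with `.to_numpy()`. A sketch on a smaller random array (the names here mirror the session above but are illustrative):

```python
import numpy as np
import pandas as pd

arr = np.random.rand(1000, 3) * 10000
df = pd.DataFrame(arr)

# Filter as above, then hand back a plain ndarray of the surviving rows
result = df[(df[0] < 2000) & (df[1] < 4000) & (df[2] < 5000)].to_numpy()
```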

A benchmark may follow, but it's very fast.

CodePudding user response:

The jit function had not been warmed up: the first call includes compilation time. After the first run, the result shows it takes only 0.07 s to finish the task.
