How to count the consecutive numbers with the same sign as each number in the first row, from np.n

Time:06-24

I have an ndarray of shape (10,6). For each column i, I want to count how many consecutive values, starting from the first row, have the same sign as data[0, i], and store that count multiplied by the sign. For example, a run of three positive values gives 3 and a single negative value gives -1. My code is below: generate_data generates the demo data, and get_result is the production code, which needs to be run tens of millions of times:

import numpy as np

rand = np.random.default_rng(seed=0)


def generate_data() -> np.ndarray:
    data = rand.uniform(-1, 1, size=(10, 6))
    return data


def get_result(data: np.ndarray) -> np.ndarray:
    dim2 = data.shape[1]
    result = np.zeros(dim2)
    data1 = np.sign(data)
    for i in range(dim2):
        a = data1[:, i]
        b = a[0]  # sign of the first element in column i
        if b == 0:
            continue  # no sign to match: leave result[i] = 0
        count = 1
        for j in range(1, a.shape[0]):
            if a[j] != b:
                break  # the run of matching signs has ended
            count += 1
        result[i] = count * b  # signed length of the run
    return result


def main() -> None:
    data = generate_data()
    print(data)
    result = get_result(data)
    print(result)
    return


if __name__ == '__main__':
    main()


My data is as follows:

print(data)
[[ 0.27392337 -0.46042657 -0.91805295 -0.96694473  0.62654048  0.82551115]
 [ 0.21327155  0.45899312  0.08724998  0.87014485  0.63170711 -0.994523  ]
 [ 0.71480855 -0.93282885  0.45931089 -0.64868876  0.72635784  0.08292244]
 [-0.40057622 -0.15462556 -0.94336066 -0.75143345  0.34124883  0.29437902]
 [ 0.23077022 -0.23264489  0.99441987  0.96167068  0.37108397  0.30091855]
 [ 0.37689346 -0.22215715 -0.72980699  0.44297668  0.05070864 -0.37951625]
 [-0.02832928  0.77897567  0.86808703 -0.28440961  0.14305966 -0.35626122]
 [ 0.18860006 -0.32417755 -0.216762    0.7805487  -0.54568481  0.24637429]
 [-0.83196931  0.6652883   0.57419661 -0.52126111  0.75296846 -0.88286393]
 [-0.32776588 -0.69944107 -0.09932127  0.59264854 -0.53871558 -0.8959574 ]]

The result I want to generate is as follows:

print(result)
[ 3. -1. -1. -1. 7. 1.]
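For reference, the same computation can also be written without Python loops. The following is a minimal sketch in plain NumPy (get_result_vectorized is a hypothetical name, not from the question). It has not been benchmarked here, and for a single 10×6 array the overhead of several NumPy calls may not beat a compiled loop. It does accept a stack of arrays with shape (..., rows, cols), so if many arrays can be collected first, one call can process all of them.

```python
import numpy as np


def get_result_vectorized(data: np.ndarray) -> np.ndarray:
    """Signed length of the leading same-sign run in each column.

    Accepts shape (rows, cols) or a stack of shape (..., rows, cols).
    """
    signs = np.sign(data)
    first = signs[..., :1, :]  # sign of row 0, kept as a row for broadcasting
    same = signs == first      # True where the sign matches row 0
    # The running product stays 1 until the first mismatch and is 0 after it,
    # so summing it along the rows gives the length of the leading run.
    run_length = np.cumprod(same, axis=-2).sum(axis=-2)
    # Attach the sign of row 0; columns that start with 0 give 0,
    # like the `continue` branch of the loop version.
    return run_length * first[..., 0, :]
```

On the data printed above it should give the same values as get_result. One trade-off: np.cumprod always scans every row, whereas the loop stops at the first mismatch.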

On my computer, timing get_result() gives:

%timeit get_result(data)
23.2 µs ± 3.82 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Answer:

If you want to speed up your iterations, I highly recommend using njit from numba. The easiest way is to place the @njit decorator before your function. Since your function only works with arrays of numeric data types, you do not need to change anything else in your code. This alone should reduce the run time of each call significantly. Note that numba compiles the function on its first call, so that first call is slow; later calls run the compiled machine code.

import numpy as np
from numba import njit

rand = np.random.default_rng(seed=0)

def generate_data() -> np.ndarray:
    data = rand.uniform(-1, 1, size=(10, 6))
    return data

@njit
def get_result_njit(data: np.ndarray) -> np.ndarray:
    dim2 = data.shape[1]
    result = np.zeros(dim2)
    data1 = np.sign(data)
    for i in range(dim2):
        a = data1[:, i]
        b = a[0]  # sign of the first element in column i
        if b == 0:
            continue  # no sign to match: leave result[i] = 0
        count = 1
        for j in range(1, a.shape[0]):
            if a[j] != b:
                break  # the run of matching signs has ended
            count += 1
        result[i] = count * b  # signed length of the run
    return result

Comparison of your function with and without the njit decorator:

%timeit get_result_njit(data)
-----------------------------------------------------------------------------
1.21 µs ± 92.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
-----------------------------------------------------------------------------

%timeit get_result(data)
-----------------------------------------------------------------------------
15.3 µs ± 252 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
-----------------------------------------------------------------------------
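%timeit is an IPython magic. In a plain Python script, a comparable per-call measurement can be sketched with the standard timeit module. The helper below (time_per_call is a hypothetical name) makes one untimed call first, because a @njit function is compiled on its first call and that compile time would otherwise distort the average.

```python
import timeit
from typing import Callable

import numpy as np


def time_per_call(func: Callable[[np.ndarray], np.ndarray],
                  data: np.ndarray, number: int = 100_000) -> float:
    """Average seconds per call, excluding one warm-up call."""
    func(data)  # warm-up: triggers numba compilation before timing starts
    total = timeit.timeit(lambda: func(data), number=number)
    return total / number
```

For example, time_per_call(get_result_njit, data) * 1e6 gives microseconds per call for the compiled version.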

Note: This is definitely not the limit. I am sure you can reduce the run time further, especially if you make full use of numba, but adding the decorator is the easiest way to get a speed-up. If you need more performance, numba is definitely worth a closer look.
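As one example of going further, here is a sketch of a leaner kernel (get_result_lean is a hypothetical name). It compares signs directly on data instead of first allocating np.sign(data) and the column slices. If numba is not installed, it falls back to plain Python. It assumes data contains no NaN values. Whether it is measurably faster than get_result_njit on your machine is something to check with %timeit.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # numba missing: run the same code as plain Python
    def njit(func):
        return func


@njit
def get_result_lean(data):
    n_rows, n_cols = data.shape
    result = np.zeros(n_cols)
    for i in range(n_cols):
        first = data[0, i]
        if first == 0.0:
            continue  # leave 0, as in the original loop
        positive = first > 0.0
        count = 1
        for j in range(1, n_rows):
            x = data[j, i]
            # Stop at the first value whose sign differs from row 0;
            # a zero also ends the run, as np.sign(0) == 0 would.
            if positive:
                if not x > 0.0:
                    break
            else:
                if not x < 0.0:
                    break
            count += 1
        result[i] = count if positive else -count
    return result
```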
