I have an ndarray of shape (10, 6). For each column, I want to count how many consecutive values, starting from the first row, have the same sign as that column's first element (data[0, i]), and return that count multiplied by the sign. My code is as follows: generate_data produces the demo data, and get_result is the production code, which needs to run tens of millions of times:
import numpy as np

rand = np.random.default_rng(seed=0)


def generate_data() -> np.ndarray:
    # Demo data: 10 rows x 6 columns drawn uniformly from [-1, 1)
    data = rand.uniform(-1, 1, size=(10, 6))
    return data


def get_result(data: np.ndarray) -> np.ndarray:
    dim2 = data.shape[1]
    result = np.zeros(dim2)
    data1 = np.sign(data)
    for i in range(dim2):
        a = data1[:, i]
        b = a[0]  # sign of the first element in column i
        if b == 0:
            continue  # leave result[i] at 0
        count = 1
        for j in range(1, a.shape[0]):
            if a[j] != b:
                break  # sign changed: stop counting
            count += 1
        result[i] = count * b
    return result


def main() -> None:
    data = generate_data()
    print(data)
    result = get_result(data)
    print(result)


if __name__ == '__main__':
    main()
My data is as follows:
print(data)
[[ 0.27392337 -0.46042657 -0.91805295 -0.96694473 0.62654048 0.82551115]
[ 0.21327155 0.45899312 0.08724998 0.87014485 0.63170711 -0.994523 ]
[ 0.71480855 -0.93282885 0.45931089 -0.64868876 0.72635784 0.08292244]
[-0.40057622 -0.15462556 -0.94336066 -0.75143345 0.34124883 0.29437902]
[ 0.23077022 -0.23264489 0.99441987 0.96167068 0.37108397 0.30091855]
[ 0.37689346 -0.22215715 -0.72980699 0.44297668 0.05070864 -0.37951625]
[-0.02832928 0.77897567 0.86808703 -0.28440961 0.14305966 -0.35626122]
[ 0.18860006 -0.32417755 -0.216762 0.7805487 -0.54568481 0.24637429]
[-0.83196931 0.6652883 0.57419661 -0.52126111 0.75296846 -0.88286393]
[-0.32776588 -0.69944107 -0.09932127 0.59264854 -0.53871558 -0.8959574 ]]
The result I want to generate is as follows:
print(result)
[ 3. -1. -1. -1. 7. 1.]
On my computer, timing get_result() gives:
%timeit get_result(data)
23.2 µs ± 3.82 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
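For reference, here is a minimal sketch (not part of the original question) of the same rule written with vectorised NumPy instead of Python loops. The name get_result_vectorized is hypothetical. The sketch relies on argmax returning the index of the first True in a boolean array, and it also accepts a stacked batch of shape (N, 10, 6), which might help if the tens of millions of calls can be grouped. Whether it is faster for a single small (10, 6) array is an assumption worth checking with %timeit, because NumPy's per-call overhead dominates at that size.

```python
import numpy as np


def get_result_vectorized(data: np.ndarray) -> np.ndarray:
    """Signed length of the leading same-sign run in each column.

    Accepts shape (rows, cols) or a batch of shape (..., rows, cols).
    """
    signs = np.sign(data)
    first = signs[..., :1, :]  # sign of row 0, kept 2-D so it broadcasts
    mismatch = signs != first  # True where the sign differs from row 0
    n_rows = data.shape[-2]
    # argmax on a boolean array gives the index of the first True;
    # columns with no mismatch at all get the full row count instead.
    run = np.where(mismatch.any(axis=-2), mismatch.argmax(axis=-2), n_rows)
    # Multiplying by the first sign yields 0 for columns that start with 0,
    # matching the loop's `continue` branch.
    return run * first[..., 0, :]
```

Applied to the data printed above, this should return the same [ 3. -1. -1. -1. 7. 1.] as get_result.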
Answer:
If you want to speed up your loops, I highly recommend using njit from numba. I think the easiest way to achieve that is to place the njit decorator directly above your function. Since your function only works on arrays of numeric types and uses operations that numba supports (np.sign, slicing, np.zeros), you do not need to change anything else in your code. This alone should reduce the run time significantly.
import numpy as np
from numba import njit

rand = np.random.default_rng(seed=0)


def generate_data() -> np.ndarray:
    data = rand.uniform(-1, 1, size=(10, 6))
    return data


@njit  # compiled to machine code on the first call
def get_result_njit(data: np.ndarray) -> np.ndarray:
    dim2 = data.shape[1]
    result = np.zeros(dim2)
    data1 = np.sign(data)
    for i in range(dim2):
        a = data1[:, i]
        b = a[0]  # sign of the first element in column i
        if b == 0:
            continue
        count = 1
        for j in range(1, a.shape[0]):
            if a[j] != b:
                break  # sign changed: stop counting
            count += 1
        result[i] = count * b
    return result
Comparison of your function with and without the njit decorator:
%timeit get_result_njit(data)
-----------------------------------------------------------------------------
1.21 µs ± 92.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
-----------------------------------------------------------------------------
%timeit get_result(data)
-----------------------------------------------------------------------------
15.3 µs ± 252 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
-----------------------------------------------------------------------------
Note: this is definitely not the limit. I am sure you can reduce the run time further, especially if you make full use of numba. But I think this is the easiest way to get a speed-up. If you need more performance, definitely check out the numba documentation.
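As one possible next step, here is a sketch that assumes finite input (no NaN) and compares the raw values against zero instead of allocating np.sign(data) and the column slices. get_result_nosign is a hypothetical name, and I have not benchmarked it, so whether it actually beats get_result_njit is an assumption to verify. The try/except fallback is only there so the sketch still runs as plain Python when numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: without numba, run the same code as plain (slow) Python.
    def njit(func):
        return func


@njit
def get_result_nosign(data):
    n_rows, n_cols = data.shape
    result = np.zeros(n_cols)
    for i in range(n_cols):
        first = data[0, i]
        if first == 0.0:
            continue  # same as the original: a zero in row 0 gives 0
        positive = first > 0.0
        count = 1
        for j in range(1, n_rows):
            value = data[j, i]
            # A sign change (or a zero) ends the run; no np.sign array needed.
            if value == 0.0 or (value > 0.0) != positive:
                break
            count += 1
        result[i] = count if positive else -count
    return result
```

Because numba compiles on the first call, call get_result_nosign(data) once before timing it with %timeit so compilation is not mixed into the measurement.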