I have three lists that I combine in a list comprehension and then sum. The lists have length n >= 1500, and I have been unable to make the code run faster than about 3 s per list comprehension. This code needs to run thousands of times, so 3 s per evaluation does not cut it.
Below is my current attempt. `split` is just a float determined earlier in my code.

```python
sum([list1[k] * (list2[k] == 1) if list3[k] < split else list1[k] * (list2[k] == -1) for k in range(n)])
```
- `list1` contains 1500 positive floats between 0 and 1 that sum to 1.
- `list2` contains 1500 randomly sampled -1's and 1's.
- `list3` contains 1500 values randomly sampled from a normal distribution (an example would be `np.random.normal(5, 0.5, 3)`).
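To make the setup reproducible, here is a minimal sketch that builds data matching that description and evaluates the original expression. The way the weights are generated (uniform draws normalized to sum to 1), the seed, and the value of `split` are assumptions for illustration only; they are not taken from the question.

```python
import numpy as np

rng = np.random.default_rng(0)  # assumed seed, for reproducibility only
n = 1500

# Assumption: positive weights obtained by normalizing uniform draws so they sum to 1.
weights = rng.uniform(0.0, 1.0, n)
list1 = (weights / weights.sum()).tolist()
list2 = rng.choice([-1, 1], n).tolist()   # random -1's and 1's
list3 = rng.normal(5, 0.5, n).tolist()    # normal distribution, mean 5, sd 0.5
split = 5.0                               # hypothetical split value

# The original expression from the question, unchanged.
total = sum([list1[k] * (list2[k] == 1) if list3[k] < split else list1[k] * (list2[k] == -1)
             for k in range(n)])
print(total)
```

Because the weights sum to 1, the result always lies between 0 and 1, which makes it easy to sanity-check faster rewrites against this version.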
I ended up writing three approaches to your question: improved Python, NumPy, and Numba.
The improved Python version, based on @KellyBelly's comment, works nicely: iterating with `zip` (rather than indexing each list with `k`) has a surprisingly strong effect on performance here. With NumPy, you want to leverage vectorized operations: turn your conditions into boolean masks and get rid of Python-level loops entirely.
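The mask idea can also be expressed with a single boolean mask: `np.where` computes, for every element, the sign that `list2` must match on that side of `split`, and one comparison then selects the entries of `list1` to add. This is a sketch, not the answer's benchmarked version; the helper name `masked_sum` is hypothetical, and the inputs are assumed to be equal-length sequences.

```python
import numpy as np


def masked_sum(a1, a2, a3, split):
    """Hypothetical helper: sum a1[k] where a2[k] is +1 below split and -1 otherwise."""
    a1, a2, a3 = np.asarray(a1), np.asarray(a2), np.asarray(a3)
    expected = np.where(a3 < split, 1, -1)  # required sign for each element
    return a1[a2 == expected].sum()         # add only the matching weights


# k=0 (3 < 5, needs +1, has +1) and k=3 (7 >= 5, needs -1, has -1) match.
a1 = np.array([0.1, 0.2, 0.3, 0.4])
a2 = np.array([1, -1, 1, -1])
a3 = np.array([3.0, 4.0, 6.0, 7.0])
print(masked_sum(a1, a2, a3, split=5.0))  # 0.5
```

Converting the lists to arrays once, outside the repeated call, avoids paying the list-to-array conversion cost on every evaluation.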
Numba is usually the fastest solution if you are comfortable with its key concepts (`njit`, `prange`, etc.). It takes a bit more proofreading than the NumPy approach, but the effort is well rewarded.
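The answer's Numba functions read module-level arrays, and Numba freezes global arrays into the compiled function as constants. For code that runs thousands of times on fresh data, passing the arrays as arguments is likely the more practical shape. Below is a sketch under that assumption; the helper name `count_matches` is hypothetical, and the `try/except` fallback exists only so the sketch still runs (without any speedup) where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:
    # Fallback (assumption): without Numba, run as plain Python so the sketch still works.
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]):
            return args[0]          # used as a bare @njit
        return lambda func: func    # used as @njit(parallel=True)
    prange = range


@njit(parallel=True)
def count_matches(a1, a2, a3, split):
    """Hypothetical helper: sum a1[k] where a2[k] == (1 if a3[k] < split else -1)."""
    total = 0.0
    for k in prange(a1.shape[0]):
        sign = 1 if a3[k] < split else -1
        if a2[k] == sign:
            total += a1[k]  # Numba turns this update into a parallel sum reduction
    return total


rng = np.random.default_rng(1)
a1 = rng.uniform(0.0, 1.0, 1500)
a2 = rng.choice(np.array([-1, 1]), 1500)
a3 = rng.normal(5.0, 0.5, 1500)

count_matches(a1, a2, a3, 5.0)  # first call compiles (when Numba is present)
print(count_matches(a1, a2, a3, 5.0))
```

Later calls with arrays of the same dtypes reuse the compiled code, so the compilation cost is paid only once per process.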
Note that these are only different ways of implementing the same algorithm. Improving the algorithm itself matters too if you are chasing those precious milliseconds.
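As one example of restructuring the arithmetic rather than the loop: because `list2` only ever holds -1 and 1, the test `list2[k] == s_k` (with `s_k = 1` below `split` and `-1` otherwise) equals `(1 + list2[k] * s_k) / 2`. The whole sum then collapses to `0.5 * (sum(list1) + dot(list1, list2 * s))`. This sketch relies on that ±1 guarantee from the question; the helper name `dot_sum` is hypothetical, and it is not something the answer benchmarked, so it is worth timing against the masked version before adopting it.

```python
import numpy as np


def dot_sum(a1, a2, a3, split):
    """Hypothetical helper; assumes every a2[k] is exactly -1 or 1."""
    a1 = np.asarray(a1, dtype=float)
    s = np.where(np.asarray(a3) < split, 1.0, -1.0)  # required sign per element
    # When a2 and s are both +/-1: [a2 == s] == (1 + a2 * s) / 2
    return 0.5 * (a1.sum() + np.dot(a1, np.asarray(a2) * s))


rng = np.random.default_rng(2)
a1 = rng.uniform(0.0, 1.0, 1500)
a1 /= a1.sum()  # weights summing to 1, as in the question
a2 = rng.choice(np.array([-1, 1]), 1500)
a3 = rng.normal(5.0, 0.5, 1500)

reference = a1[a2 == np.where(a3 < 5.0, 1, -1)].sum()
print(np.isclose(dot_sum(a1, a2, a3, 5.0), reference))  # True
```

If `sum(list1)` is always 1, as the question states, the first term is a constant and only the dot product has to be recomputed.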
Timings:
Items | List comprehension | Zipped iterator | Numpy arrays | Numba.njit | Numba.njit(parallel=True) |
---|---|---|---|---|---|
1 k | 0.191 ms | 0.129 ms | 0.487 ms | 0.006 ms | 0.013 ms |
10 k | 2.288 ms | 1.206 ms | 0.477 ms | 0.048 ms | 0.019 ms |
100 k | 18.941 ms | 13.245 ms | 2.857 ms | 0.477 ms | 0.056 ms |
Code:
```python
# Imports.
from contextlib import contextmanager
import time

import numba as nb
import numpy as np

np.random.seed(0)

# Data.
N = 100000
SPLIT = 50
array1 = np.random.randint(0, 100, N)
array2 = np.random.choice((1, -1), N)
array3 = np.random.randint(0, 100, N)
list1, list2, list3 = map(lambda a: a.tolist(), (array1, array2, array3))
print(N)


# Helpful timing function.
@contextmanager
def time_this():
    t0 = time.perf_counter()
    yield
    dt = time.perf_counter() - t0
    print(f"{dt*1000:.3f} ms")


# List comprehension (the original approach).
def list_comprehension():
    n = len(list1)
    return sum([list1[k] * (list2[k] == 1) if list3[k] < SPLIT else list1[k] * (list2[k] == -1)
                for k in range(n)])


# Zipped iterator: add l1 only when l2 equals the sign expected on that side of SPLIT.
def zipped_iterator():
    return sum(l1 if l2 == (1 if l3 < SPLIT else -1) else 0 for l1, l2, l3 in zip(list1, list2, list3))


# Numpy arrays: boolean masks replace the Python loop.
def numpy_arrays():
    mask = array3 < SPLIT
    positives = array1[mask] * (array2[mask] == 1)
    negatives = array1[~mask] * (array2[~mask] == -1)
    return positives.sum() + negatives.sum()


# Numba (serial). Numba freezes the global arrays into the compiled function as constants.
@nb.njit
def numba_count():
    total = 0
    n = len(array1)
    for k in range(n):
        if array3[k] < SPLIT:
            sign = 1
        else:
            sign = -1
        if array2[k] == sign:
            total += array1[k]
    return total


# Numba in parallel: prange splits the loop across threads, `total += ...` becomes a reduction.
@nb.njit(parallel=True)
def numba_count2():
    total = 0
    n = len(array1)
    for k in nb.prange(n):
        if array3[k] < SPLIT:
            sign = 1
        else:
            sign = -1
        if array2[k] == sign:
            total += array1[k]
    return total


# Timings.
totals = []
with time_this():
    totals.append(list_comprehension())
with time_this():
    totals.append(zipped_iterator())
with time_this():
    totals.append(numpy_arrays())
numba_count()  # Compile before we time anything.
with time_this():
    totals.append(numba_count())
numba_count2()  # Compile before we time anything.
with time_this():
    totals.append(numba_count2())

# Assert that all the returned values are identical.
assert np.isclose(totals, totals[0]).all()
```
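The `time_this` helper measures a single run, which can be noisy for sub-millisecond timings like the ones in the table. A sketch using the standard-library `timeit.repeat` and keeping the best of several repeats; the helper name `best_ms` and the default `repeat`/`number` values are arbitrary choices, not part of the answer.

```python
import timeit


def best_ms(func, repeat=5, number=100):
    """Hypothetical helper: best average time per call across repeats, in milliseconds."""
    runs = timeit.repeat(func, repeat=repeat, number=number)
    return min(runs) / number * 1000


data = list(range(1000))
print(f"{best_ms(lambda: sum(data)):.4f} ms")
```

With the answer's code loaded, something like `best_ms(numba_count2)` (after the warm-up call) would give a steadier figure than a single `time_this` block.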