I've got the following code to calculate the average position of 1's in a 2D numpy array that contains 1's and 0's. The issue is that it's very slow and I was wondering if a faster method is possible?
row_sum = 0
col_sum = 0
ones_count = 0
for row_count, row in enumerate(array):
    for col_count, col in enumerate(row):
        if col == 1:
            row_sum += row_count
            col_sum += col_count
            ones_count += 1
average_position_ones = (row_count / ones_count, col_count / ones_count)
CodePudding user response:
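Looking at your code, you can get the number of ones with array.sum() (provided the array contains only 0s and 1s). After the loops, row_count and col_count hold the last row and column index, so your final line as written is equivalent to:
ones_count = array.sum()
print((array.shape[0] - 1) / ones_count, (array.shape[1] - 1) / ones_count)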
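If the intent is the mean row and column index of the ones, that is row_sum and col_sum divided by ones_count, the same array.sum() count can be combined with per-axis sums weighted by index. A minimal sketch, assuming a small hypothetical 0/1 array named grid (not from the question):

```python
import numpy as np

# Hypothetical 0/1 test array; ones sit at (0, 1), (1, 2), (2, 1), (2, 2).
grid = np.array([[0, 1, 0],
                 [0, 0, 1],
                 [0, 1, 1]])

ones_count = grid.sum()

# Number of ones in each row, weighted by the row index.
row_sum = grid.sum(axis=1) @ np.arange(grid.shape[0])
# Number of ones in each column, weighted by the column index.
col_sum = grid.sum(axis=0) @ np.arange(grid.shape[1])

average_position_ones = (row_sum / ones_count, col_sum / ones_count)
print(average_position_ones)
```

This only works when every entry is 0 or 1, and it divides by zero if the array contains no ones.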
CodePudding user response:
Here are 3 ways to calculate row_sum, col_sum and ones_count more quickly.
Baseline
For testing I use this array
import numpy as np
import numba as nb
np.random.seed(1)
n = 10**4
array = np.random.randint(0,2,(n,n))
Your exact code takes 20.3 s ± 397 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) on my machine.
The lazy one liner numpy version:
%timeit np.stack(np.where(array)).sum(axis=1),array.sum()
This takes 1.13 s ± 12.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) on my machine.
Here np.stack(np.where(array)).sum(axis=1) gives what you call row_sum and col_sum, and array.sum() gives your ones_count.
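To turn those sums into the average position, divide by the count. A minimal sketch on a small hypothetical array named grid (an assumption for illustration, not the benchmark array); np.nonzero is what np.where does when called with a single argument:

```python
import numpy as np

# Hypothetical 0/1 test array; ones sit at (0, 1), (1, 2), (2, 1), (2, 2).
grid = np.array([[0, 1, 0],
                 [0, 0, 1],
                 [0, 1, 1]])

# Row and column indices of every nonzero entry.
rows, cols = np.nonzero(grid)

# Element 0 is row_sum, element 1 is col_sum.
index_sums = np.stack((rows, cols)).sum(axis=1)
ones_count = grid.sum()

average_position = index_sums / ones_count
print(average_position)

# The same result in one call: argwhere returns one (row, col) pair per one.
print(np.argwhere(grid).mean(axis=0))
```

Both forms build an index array with one entry per one in the input, which is probably where much of the time goes on a large, dense array.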
Avoid looping through twice
You can use your exact code with numba.njit:
@nb.njit
def test():
    row_sum = 0
    col_sum = 0
    ones_count = 0
    for row_count, row in enumerate(array):
        for col_count, col in enumerate(row):
            if col == 1:
                row_sum += row_count
                col_sum += col_count
                ones_count += 1
    return row_sum, col_sum, ones_count
%timeit test()
This is a bit faster. It takes 50 ms ± 614 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) on my machine. But definitely not worth the effort.
Multicore version
A slight modification of your code can run multithreaded with numba:
@nb.njit(parallel=True)
def test2():
    row_sum = 0
    col_sum = 0
    ones_count = 0
    # prange splits the rows across threads; the += updates are reductions
    for row_count in nb.prange(len(array)):
        row = array[row_count]
        for col_count, col in enumerate(row):
            if col == 1:
                row_sum += row_count
                col_sum += col_count
                ones_count += 1
    return row_sum, col_sum, ones_count
%timeit test2()
Now this does give a little speedup compared to the lazy numpy version. It takes 13.3 ms ± 2.14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) on my 10-core machine.