Calculate the average position of a value in an array fast method

Time:04-15

I've got the following code to calculate the average position of the 1s in a 2D NumPy array that contains only 1s and 0s. The issue is that it's very slow, and I was wondering whether a faster method is possible.

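import numpy as np

array = np.random.randint(0, 2, (100, 100))  # example 0/1 input (hypothetical)

row_sum = 0
col_sum = 0
ones_count = 0

# Visit every cell; for each 1, accumulate its row and column index
for row_count, row in enumerate(array):
    for col_count, col in enumerate(row):
        if col == 1:
            row_sum += row_count
            col_sum += col_count
            ones_count += 1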

average_position_ones = (row_count / ones_count, col_count / ones_count)
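Note that the last line divides row_count and col_count, which after the loops simply hold the last row and column indices, rather than the accumulated row_sum and col_sum. As a reference for checking faster versions, here is a minimal sketch of the intended computation wrapped in a function. The name average_position_loop and the small example array are hypothetical, and returning None for an array without 1s is an assumption.

```python
import numpy as np


def average_position_loop(array):
    """Mean (row, col) index of the 1s, computed with plain Python loops.

    Hypothetical reference helper: slow, but easy to check by hand.
    Returns None when the array contains no 1s.
    """
    row_sum = 0
    col_sum = 0
    ones_count = 0
    for row_index, row in enumerate(array):
        for col_index, value in enumerate(row):
            if value == 1:
                row_sum += row_index
                col_sum += col_index
                ones_count += 1
    if ones_count == 0:
        return None
    # Divide the accumulated index sums, not the last loop indices
    return (row_sum / ones_count, col_sum / ones_count)


example = np.array([[0, 1, 1],
                    [0, 0, 0],
                    [1, 0, 1]])
print(average_position_loop(example))  # (1.0, 1.25)
```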

Answer:

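Looking at your code, row_count and col_count end up holding the last row and column indices (array.shape[0] - 1 and array.shape[1] - 1), so the only value that really needs computing is ones_count, which is simply the sum of the array (provided the array contains only 0s and 1s):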

ones_count = array.sum()

print((array.shape[0] - 1) / ones_count, (array.shape[1] - 1) / ones_count)
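If what you are after is the actual mean row and column index of the 1s, that is row_sum / ones_count and col_sum / ones_count, NumPy can return the coordinates of the nonzero entries directly. A minimal sketch using np.argwhere, assuming the array holds only 0s and 1s (any nonzero value would be counted); the function name average_position_numpy is hypothetical:

```python
import numpy as np


def average_position_numpy(array):
    """Mean (row, col) index of the nonzero entries, or None if there are none."""
    coords = np.argwhere(array)  # shape (k, 2): one (row, col) pair per nonzero entry
    if coords.shape[0] == 0:
        return None  # avoid the empty-mean warning and NaN result
    row_mean, col_mean = coords.mean(axis=0)
    return (float(row_mean), float(col_mean))


example = np.array([[0, 1, 1],
                    [0, 0, 0],
                    [1, 0, 1]])
print(average_position_numpy(example))  # (1.0, 1.25)
```

The design choice here is to trade memory for speed: argwhere materializes a (k, 2) array of coordinates, which for a large, dense array can be sizeable.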

Answer:

Here are three faster ways to calculate row_sum, col_sum and ones_count.

Baseline

For testing, I use this array:

import numpy as np
import numba as nb

np.random.seed(1)

n = 10**4
array = np.random.randint(0,2,(n,n))

Your exact code takes 20.3 s ± 397 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) on my machine.
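The timings in this answer are in the format printed by IPython's %timeit magic. Outside IPython, a comparable measurement can be taken with the standard-library timeit module. A minimal sketch, assuming a much smaller n than above so the pure-Python loop finishes quickly (the function name baseline is hypothetical, and the numbers it prints will differ by machine):

```python
import timeit

import numpy as np

np.random.seed(1)
n = 200  # much smaller than 10**4, so the pure-Python loop finishes quickly
array = np.random.randint(0, 2, (n, n))


def baseline():
    # The question's double loop, returning the three accumulated values
    row_sum = col_sum = ones_count = 0
    for row_count, row in enumerate(array):
        for col_count, col in enumerate(row):
            if col == 1:
                row_sum += row_count
                col_sum += col_count
                ones_count += 1
    return row_sum, col_sum, ones_count


# 7 runs of 1 call each, mirroring "7 runs, 1 loop each" in the %timeit output
times = timeit.repeat(baseline, repeat=7, number=1)
print(f"{np.mean(times) * 1e3:.1f} ms ± {np.std(times) * 1e3:.1f} ms per loop "
      f"(mean ± std. dev. of {len(times)} runs, 1 loop each)")
```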

The lazy one-liner NumPy version:

%timeit np.stack(np.where(array)).sum(axis=1), array.sum()

takes 1.13 s ± 12.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) on my machine.

Here np.stack(np.where(array)).sum(axis=1) gives what you call row_sum and col_sum, and array.sum() gives your ones_count.
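To see what the one-liner computes, here it is on a tiny 2×2 array (a hypothetical example): np.where returns one index array per dimension, np.stack turns them into a 2×k array, and summing along axis 1 adds up the row indices and the column indices separately.

```python
import numpy as np

array = np.array([[0, 1],
                  [1, 1]])

rows, cols = np.where(array)      # rows = [0, 1, 1], cols = [1, 0, 1]
stacked = np.stack((rows, cols))  # shape (2, 3): row indices on top, column indices below
row_sum, col_sum = stacked.sum(axis=1)
ones_count = array.sum()

print(row_sum, col_sum, ones_count)  # 2 2 3
```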

Avoid looping through twice

You can use your exact code with numba's JIT compiler (nb.njit):

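@nb.njit
def test(array):
    # Same double loop as the question; numba compiles it to machine code.
    # The array is passed in rather than read as a global, since numba
    # freezes global arrays into the compiled function as constants.
    row_sum = 0
    col_sum = 0
    ones_count = 0

    for row_count, row in enumerate(array):
        for col_count, col in enumerate(row):
            if col == 1:
                row_sum += row_count
                col_sum += col_count
                ones_count += 1

    return row_sum, col_sum, ones_count

test(array)  # the first call triggers compilation
%timeit test(array)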

This is a bit faster: it takes 50 ms ± 614 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) on my machine. But it is definitely not worth the effort.
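If you would rather avoid numba altogether, here is a NumPy-only alternative that is not part of this answer and whose speed has not been measured here. It avoids building the index arrays from np.where: every 1 in row i contributes i to row_sum, so row_sum is the dot product of the row indices with the per-row counts of 1s (and likewise for columns). A minimal sketch, assuming the array contains only 0s and 1s; the function name index_sums is hypothetical:

```python
import numpy as np


def index_sums(array):
    """Return (row_sum, col_sum, ones_count) for a 2D 0/1 array.

    Uses per-row and per-column counts instead of the coordinates of every 1.
    """
    row_counts = array.sum(axis=1)  # number of 1s in each row
    col_counts = array.sum(axis=0)  # number of 1s in each column
    row_sum = int(np.arange(array.shape[0]) @ row_counts)
    col_sum = int(np.arange(array.shape[1]) @ col_counts)
    ones_count = int(row_counts.sum())
    return row_sum, col_sum, ones_count


example = np.array([[0, 1],
                    [1, 1]])
print(index_sums(example))  # (2, 2, 3)
```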

Multicore version

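With a slight modification (nb.prange for the outer loop), your code can run multithreaded with numba. Numba treats the += updates of row_sum, col_sum and ones_count inside the prange loop as reductions and combines the per-thread results: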

@nb.njit(parallel=True)
def test2(array):
    row_sum = 0
    col_sum = 0
    ones_count = 0

    # prange splits the rows across threads; the += updates below are reductions
    for row_count in nb.prange(len(array)):
        row = array[row_count]
        for col_count, col in enumerate(row):
            if col == 1:
                row_sum += row_count
                col_sum += col_count
                ones_count += 1

    return row_sum, col_sum, ones_count

test2(array)  # the first call triggers compilation
%timeit test2(array)

Now this does give a little speed-up compared to the lazy NumPy version. It takes 13.3 ms ± 2.14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) on my 10-core machine.
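The prange version works because each thread accumulates its own partial row_sum, col_sum and ones_count for its share of the rows, and the partial results are combined at the end (a reduction). The same split-and-combine idea can be sketched in plain Python with a thread pool. This is a swapped-in technique, not the answer's numba code; any speed-up depends on how much of the NumPy work releases the GIL, which has not been measured here. The function names and the default chunk count are assumptions:

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def chunk_sums(block, row_offset):
    """Partial (row_sum, col_sum, ones_count) for one horizontal block of rows."""
    rows, cols = np.nonzero(block)
    # Local row indices start at 0, so shift them back by the block's offset
    return int(rows.sum()) + row_offset * len(rows), int(cols.sum()), len(rows)


def parallel_sums(array, n_chunks=4):
    """Split the rows into blocks, reduce each block in a thread, then combine."""
    bounds = np.linspace(0, array.shape[0], n_chunks + 1, dtype=int)
    with ThreadPoolExecutor(max_workers=n_chunks) as pool:
        futures = [pool.submit(chunk_sums, array[start:stop], int(start))
                   for start, stop in zip(bounds[:-1], bounds[1:])]
        partials = [f.result() for f in futures]
    # Combine the per-block partial results, like the reduction step in prange
    row_sum = sum(p[0] for p in partials)
    col_sum = sum(p[1] for p in partials)
    ones_count = sum(p[2] for p in partials)
    return row_sum, col_sum, ones_count


np.random.seed(1)
array = np.random.randint(0, 2, (1000, 800))
rows, cols = np.nonzero(array)
print(parallel_sums(array) == (int(rows.sum()), int(cols.sum()), len(rows)))  # True
```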
