Numpy speed efficiency using broadcasting, transpose and reshape in large size array

Time:10-23

Is there a way to speed up the following line of code:

import numpy as np

desired_channel=32
len_indices=50000
fast_idx = np.broadcast_to(np.arange(desired_channel)[:, None], (desired_channel, len_indices)).T.reshape(-1)

Thank you.

Answer:

I am new to the JAX library. I compared your code with a JAX version using the following code on a Colab TPU:

import numpy as np
from jax import jit
import jax.numpy as jnp
import timeit

desired_channel=32
len_indices=50000

def ex_():
    return np.broadcast_to(np.arange(desired_channel)[:, None], (desired_channel, len_indices)).T.reshape(-1)
%timeit -n1000 -r10 ex_()

@jit
def exj_():
    return jnp.broadcast_to(jnp.arange(desired_channel)[:, None], (desired_channel, len_indices)).T.reshape(-1)
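# Note: JAX dispatches work asynchronously, so appending .block_until_ready()
# to the call makes %timeit wait until the result has actually been computed.
%timeit -n1000 -r10 exj_()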

In one run, the results were:

1000 loops, best of 10: 901 µs per loop
1000 loops, best of 10: 317 µs per loop

In this run, JAX sped up your code by a factor of about 2.8 (317 µs vs 901 µs per loop).

Answer:

The last line of code is simply equal to np.tile(np.arange(desired_channel), len_indices).
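A quick way to check this equivalence, and to see why the original expression is not free, is the sketch below. The broadcast view has a zero stride along its first axis, so after the transpose reshape(-1) cannot return a view and has to copy the data into a new array, which is the same amount of work np.tile does.

```python
import numpy as np

desired_channel = 32
len_indices = 50000

# Original expression: (32, 1) is broadcast to a (32, 50000) view with a zero
# stride, transposed to (50000, 32), then flattened. The flatten must copy.
original = np.broadcast_to(
    np.arange(desired_channel)[:, None], (desired_channel, len_indices)
).T.reshape(-1)

# Equivalent: repeat the block [0, 1, ..., 31] len_indices times.
tiled = np.tile(np.arange(desired_channel), len_indices)

print(original.shape, tiled.shape)  # (1600000,) (1600000,)
print(np.array_equal(original, tiled))  # True
```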

On my machine, the performance of np.tile, like that of many NumPy calls, is bounded by the operating system (page faults), the memory allocator, and the memory throughput. There are two ways to overcome this limitation: avoid allocating and filling temporary buffers (for example, by writing into a preallocated output array), and produce smaller arrays by using narrower types such as np.uint8 or np.uint16, depending on your needs.
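As a NumPy-only illustration of both ideas (my own sketch, not part of the answer, and not timed here), the output can be allocated once with a narrow dtype and filled in place by broadcasting a single row of values into a 2-D view of the buffer. It assumes np.uint8 is wide enough, which holds because the values only go up to desired_channel - 1 = 31; fill_tiled is a hypothetical helper name.

```python
import numpy as np

def fill_tiled(out, desired_channel):
    """Hypothetical helper: write 0..desired_channel-1 repeatedly into `out` in place."""
    # reshape only returns a view (and thus writes into `out`) for contiguous input.
    if not out.flags.c_contiguous:
        raise ValueError("out must be a contiguous 1-D array")
    # View the flat buffer as (len_indices, desired_channel) without copying.
    rows = out.reshape(-1, desired_channel)
    # Broadcast one row of values into every row; no full-size temporary is created.
    rows[:] = np.arange(desired_channel, dtype=out.dtype)
    return out

desired_channel = 32
len_indices = 50000

# Allocate once with a narrow dtype (values never exceed 31) and reuse it across calls.
buffer = np.empty(desired_channel * len_indices, dtype=np.uint8)
fill_tiled(buffer, desired_channel)
```

Reusing buffer across calls avoids repeated allocation and the associated page faults; how close this gets to the Numba version below depends on the machine, so it is worth timing with %timeit in your own setup.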

Since np.tile has no out parameter, Numba can be used to write a fast alternative that fills a preallocated buffer. Here is an example:

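import numpy as np
import numba as nb

# Explicit signature: returns a contiguous int32 array and takes two int32
# scalars plus a contiguous int32 output buffer that is filled in place.
@nb.njit('int32[::1](int32, int32, int32[::1])', parallel=True)
def generate(desired_channel, len_indices, out):
    # Each outer iteration writes its own block, so prange can split them across threads.
    for i in nb.prange(len_indices):
        for j in range(desired_channel):
            out[i*desired_channel + j] = j
    return out

desired_channel=32
len_indices=50000
# Preallocate the output once and reuse it on every call.
buffer = np.full(desired_channel * len_indices, 0, dtype=np.int32)
%timeit -n 200 generate(desired_channel, len_indices, buffer)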

Here are the performance results:

Original code: 1.25 ms
np.tile:       1.24 ms
Numba:         0.20 ms
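The index arithmetic in the Numba kernel writes value j of block i at flat position i*desired_channel + j, i.e. the row-major order of a (len_indices, desired_channel) array. Because each outer iteration i writes only the disjoint slice from i*desired_channel to (i+1)*desired_channel, the prange iterations are independent, which is what makes parallel=True safe here. The following Numba-free sketch (plain Python loops, so only suitable for small sizes; generate_reference is a hypothetical name) runs the same loop body and can be compared against np.tile.

```python
import numpy as np

def generate_reference(desired_channel, len_indices, out):
    # Same loop body as the Numba kernel, without compilation or parallelism.
    for i in range(len_indices):
        # Block i occupies out[i*desired_channel : (i+1)*desired_channel].
        for j in range(desired_channel):
            out[i * desired_channel + j] = j
    return out

# Small sizes so the pure-Python loop finishes quickly.
small_channel, small_len = 4, 3
buf = np.zeros(small_channel * small_len, dtype=np.int32)
generate_reference(small_channel, small_len, buf)
print(buf)  # [0 1 2 3 0 1 2 3 0 1 2 3]
```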