Is there a way to speed up the following line of code:
desired_channel=32
len_indices=50000
fast_idx = np.broadcast_to(np.arange(desired_channel)[:, None], (desired_channel, len_indices)).T.reshape(-1)
Thank you.
CodePudding user response:
I am new to the JAX library. I compared your code with a JAX version, using the following code on a Colab TPU:
import numpy as np
from jax import jit
import jax.numpy as jnp
import timeit
desired_channel=32
len_indices=50000
def ex_():
    return np.broadcast_to(np.arange(desired_channel)[:, None], (desired_channel, len_indices)).T.reshape(-1)
%timeit -n1000 -r10 ex_()
@jit
def exj_():
    return jnp.broadcast_to(jnp.arange(desired_channel)[:, None], (desired_channel, len_indices)).T.reshape(-1)
%timeit -n1000 -r10 exj_()
In one of my runs, the results were (NumPy first, then JAX):
1000 loops, best of 10: 901 µs per loop
1000 loops, best of 10: 317 µs per loop
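So, in this test, JAX sped up your code by about two to three times. Note that JAX dispatches work asynchronously, so a more careful benchmark would call .block_until_ready() on the result before stopping the timer.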
CodePudding user response:
The last line of code is simply equivalent to np.tile(np.arange(desired_channel), len_indices).
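The equivalence can be checked directly. A minimal sketch, using the sizes from the question:

```python
import numpy as np

desired_channel = 32
len_indices = 50000

# Original expression from the question.
original = np.broadcast_to(
    np.arange(desired_channel)[:, None], (desired_channel, len_indices)
).T.reshape(-1)

# Proposed equivalent: repeat 0..desired_channel-1, len_indices times.
tiled = np.tile(np.arange(desired_channel), len_indices)

# True when both arrays have the same shape and the same elements.
same = np.array_equal(original, tiled)
print(same)
```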
On my machine, the performance of np.tile, like that of many NumPy calls, is bounded by the operating system (page faults), the memory allocator and the memory throughput. There are two ways to overcome this limitation: avoid allocating and filling temporary buffers, and produce smaller arrays in memory by using shorter types such as np.uint8 or np.uint16, depending on your needs.
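Both ideas can be sketched in plain NumPy before reaching for a compiler. This is only an illustration, not a benchmark: fill_pattern is a hypothetical helper name, and np.uint8 assumes the channel indices fit in 0..255 (true for 32 channels).

```python
import numpy as np

def fill_pattern(desired_channel, len_indices, out):
    """Write 0..desired_channel-1, repeated len_indices times, into out."""
    # View the flat buffer as (len_indices, desired_channel) without copying.
    # Assumption: out is 1-D and contiguous, otherwise reshape may copy.
    view = out.reshape(len_indices, desired_channel)
    # Broadcasting assignment writes the pattern into every row in place.
    view[:] = np.arange(desired_channel, dtype=out.dtype)
    return out

desired_channel = 32
len_indices = 50000

# Allocate once with a small dtype, then reuse the buffer across calls.
buffer = np.empty(desired_channel * len_indices, dtype=np.uint8)
fill_pattern(desired_channel, len_indices, buffer)
```

With np.uint8 the buffer takes 1.6 MB, against 6.4 MB for the same array in np.int32, so there is less memory to allocate and fill.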
Since there is no out parameter for the np.tile function, Numba can be used to generate a fast alternative function. Here is an example:
import numpy as np
import numba as nb
@nb.njit('int32[::1](int32, int32, int32[::1])', parallel=True)
def generate(desired_channel, len_indices, out):
    # Fill the preallocated buffer in parallel: each i writes one run 0..desired_channel-1.
    for i in nb.prange(len_indices):
        for j in range(desired_channel):
            out[i*desired_channel + j] = j
    return out
desired_channel=32
len_indices=50000
# Allocate the output buffer once; it is reused across calls.
buffer = np.full(desired_channel * len_indices, 0, dtype=np.int32)
%timeit -n 200 generate(desired_channel, len_indices, buffer)
Here are the performance results:
Original code: 1.25 ms
np.tile: 1.24 ms
Numba: 0.20 ms
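The %timeit lines above only run in IPython or Jupyter. Here is a rough sketch of how the two NumPy variants could be timed from a plain Python script with the standard timeit module. The function names original and tiled are chosen here for illustration, and absolute timings depend on the machine, so none are shown:

```python
import timeit
import numpy as np

desired_channel = 32
len_indices = 50000

def original():
    # Expression from the question.
    return np.broadcast_to(
        np.arange(desired_channel)[:, None], (desired_channel, len_indices)
    ).T.reshape(-1)

def tiled():
    # np.tile equivalent.
    return np.tile(np.arange(desired_channel), len_indices)

for name, fn in [("Original code", original), ("np.tile", tiled)]:
    # Best of 5 repeats of 20 calls each, reported per call in milliseconds.
    best = min(timeit.repeat(fn, number=20, repeat=5)) / 20
    print(f"{name}: {best * 1e3:.2f} ms")
```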