Construct array by sampling over every n'th element along last axis

Time:11-20

Let a be some (not necessarily one-dimensional) NumPy array with n * m elements along its last axis. I wish to "split" this array along its last axis by taking every n'th element, starting at each offset from 0 up to n - 1.

To be explicit, let a have shape (k, n * m); then I wish to construct the array of shape (n, k, m):

np.array([a[:, i::n] for i in range(n)])

My problem is that, although this does return the array I want, I still feel there might be a more efficient and neater NumPy routine for this.

Cheers!
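For comparison, the same result can be expressed without a Python loop by reshaping the last axis into (m, n) and then moving the offset axis to the front. This sketch is not from the original post: the function name `split_every_nth` is hypothetical, and it assumes the last axis length is a multiple of n. For a C-contiguous input, both `reshape` and `np.moveaxis` return views, so no data is copied until a contiguous result is explicitly requested (for example with `np.ascontiguousarray`). Leading axes are kept as they are, which covers inputs that are not two-dimensional.

```python
import numpy as np


def split_every_nth(a, n):
    """Return an array of shape (n, *a.shape[:-1], m) with result[i] == a[..., i::n].

    Sketch: assumes a.shape[-1] is a multiple of n.
    """
    *lead, last = a.shape
    if last % n != 0:
        raise ValueError("last axis length must be a multiple of n")
    m = last // n
    # Element a[..., s*n + i] lands at position [..., s, i] after this reshape.
    b = a.reshape(*lead, m, n)
    # Move the offset axis i to the front: result[i, ..., s] == a[..., s*n + i].
    return np.moveaxis(b, -1, 0)


# Small check against the list comprehension above (k=2, n=3, m=4).
k, n, m = 2, 3, 4
a = np.arange(k * n * m).reshape(k, n * m)
expected = np.array([a[:, i::n] for i in range(n)])
result = split_every_nth(a, n)
print(result.shape)                     # (3, 2, 4)
print(np.array_equal(result, expected))  # True
```

The result is a non-contiguous view, so follow it with `np.ascontiguousarray(...)` if later code needs a compact copy.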

Answer:

I think this does what you want, without loops. I tested it on 2D inputs; it may need some adjustments for more dimensions.

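import numpy as np
# Flat indexes 0, n, 2n, ... plus a per-block offset 0..n-1; mode='wrap' reduces them modulo a.size
indexes = np.arange(0, a.size*n, n) + np.repeat(np.arange(n), a.size // n)
np.take(a, indexes, mode='wrap').reshape(n, a.shape[0], -1)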

In my testing it is a bit slower than your original list solution.
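A self-contained version of the `np.take` approach, wrapped in a hypothetical `take_split` function and checked against the question's list comprehension on a small random 2D input. It combines the two index arrays with `+` and uses integer division `a.size // n`, since `np.repeat` needs an integer repeat count. As the answer notes, only 2D input is assumed.

```python
import numpy as np


def take_split(a, n):
    """Index-wrapping approach for a 2D array a of shape (k, n * m) -> (n, k, m)."""
    # Flat positions 0, n, 2n, ... shifted by a per-block offset 0..n-1.
    indexes = np.arange(0, a.size * n, n) + np.repeat(np.arange(n), a.size // n)
    # mode='wrap' reduces each index modulo a.size, so block i picks a.flat[t*n + i].
    return np.take(a, indexes, mode='wrap').reshape(n, a.shape[0], -1)


# Compare against the list comprehension from the question.
rng = np.random.default_rng(0)
k, n, m = 3, 4, 5
a = rng.integers(0, 100, size=(k, n * m))
expected = np.array([a[:, i::n] for i in range(n)])
print(np.array_equal(take_split(a, n), expected))  # True
```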

Answer:

It is hard to write a faster NumPy implementation. One efficient solution is to use Numba to speed this up. That being said, the memory access pattern can be the main reason why the code is slow on relatively large matrices, so one needs to care about the iteration order to keep the accesses relatively cache-friendly. Moreover, for large arrays, it can be a good idea to use multiple threads to better mitigate the overhead of the relatively high memory latency (due to the strided access pattern). Here is an implementation:

import numba as nb
import numpy as np

# The first call is slower due to the JIT compilation.
# Please consider specifying the signature of the function (i.e. input types)
# so that it is compiled eagerly, when it is defined.
@nb.njit  # Use nb.njit(parallel=True) for the parallel version
def compute(arr, n):
    k, m = arr.shape[0], arr.shape[1] // n
    assert arr.shape[1] == n * m

    out = np.empty((n, k, m), dtype=arr.dtype)

    # Use nb.prange instead of range here for the parallel version
    for i2 in range(k):
        inView = arr[i2]  # one row of the input
        for i1 in range(n):
            outView = out[i1, i2]
            cur = i1
            for i3 in range(m):
                outView[i3] = inView[cur]  # every n'th element starting at offset i1
                cur += n

    return out
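One way to sanity-check the kernel's indexing without Numba installed is to run the same loop nest as plain Python on a tiny input and compare it with the slicing expression from the question. The sketch below is a hypothetical `compute_reference` helper, not part of the answer. It is far too slow for real sizes and only serves as a correctness reference.

```python
import numpy as np


def compute_reference(arr, n):
    """Same loop nest as the Numba kernel, without @nb.njit (slow; for checking only)."""
    k, m = arr.shape[0], arr.shape[1] // n
    assert arr.shape[1] == n * m
    out = np.empty((n, k, m), dtype=arr.dtype)
    for i2 in range(k):          # input rows, outermost
        inView = arr[i2]
        for i1 in range(n):      # offset within each group of n
            outView = out[i1, i2]
            cur = i1
            for i3 in range(m):  # walk the row with stride n
                outView[i3] = inView[cur]
                cur += n
    return out


a = np.arange(2 * 3 * 4, dtype=np.int32).reshape(2, 3 * 4)
expected = np.array([a[:, i::3] for i in range(3)])
print(np.array_equal(compute_reference(a, 3), expected))  # True
```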

Here are the results on my machine with an i5-9600KF processor (6 cores) for k=37, n=42, m=53 and a.dtype=np.int32:

John Zwinck's solution:    986.1 µs
Initial implementation:     91.7 µs
Sequential Numba:           62.9 µs
Parallel Numba:             14.7 µs
Optimal lower bound:        ~7.0 µs
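A minimal sketch for reproducing this kind of comparison with `timeit`, using the same sizes and dtype. It covers only pure-NumPy variants (the Numba kernels are left out so the script runs without Numba). `reshape_transpose` is an added reshape/transpose alternative that is not from either answer; it copies to a contiguous array so that all three functions return comparable output. Absolute timings depend on the machine, so none are claimed here.

```python
import timeit

import numpy as np

# Sizes and dtype taken from the benchmark above.
k, n, m = 37, 42, 53
rng = np.random.default_rng(0)
a = rng.integers(0, 1000, size=(k, n * m), dtype=np.int32)


def list_comprehension(a, n):
    return np.array([a[:, i::n] for i in range(n)])


def take_wrap(a, n):
    indexes = np.arange(0, a.size * n, n) + np.repeat(np.arange(n), a.size // n)
    return np.take(a, indexes, mode='wrap').reshape(n, a.shape[0], -1)


def reshape_transpose(a, n):
    # (k, n*m) -> (k, m, n) -> (n, k, m), then copy to a C-contiguous array.
    return np.ascontiguousarray(a.reshape(a.shape[0], -1, n).transpose(2, 0, 1))


expected = list_comprehension(a, n)
for fn in (list_comprehension, take_wrap, reshape_transpose):
    assert np.array_equal(fn(a, n), expected), fn.__name__
    # Best of 5 repeats, averaged over 100 calls each.
    seconds = min(timeit.repeat(lambda: fn(a, n), number=100, repeat=5)) / 100
    print(f"{fn.__name__:>20}: {seconds * 1e6:8.1f} µs")
```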