Efficiently insert alternate rows and columns in Python


I need to interleave PyTorch tensors (similar to numpy arrays) with rows and columns of zeros. Like this:

Input  => [[1, 2, 3],
           [4, 5, 6],
           [7, 8, 9]]

Output => [[1, 0, 2, 0, 3],
           [0, 0, 0, 0, 0],
           [4, 0, 5, 0, 6],
           [0, 0, 0, 0, 0],
           [7, 0, 8, 0, 9]]

I am using the accepted answer to this question, which proposes the following:

import numpy as np

def insert_zeros(a, N=1):
    # a : input array
    # N : number of zeros to be inserted between consecutive rows and cols
    out = np.zeros((N + 1) * np.array(a.shape) - N, dtype=a.dtype)
    out[::N + 1, ::N + 1] = a
    return out

The answer works perfectly, except that I need to perform this many times on many arrays, and the time it takes has become the bottleneck. It is the step-sized slicing that takes most of the time.

For what it's worth, the matrices I am using it for are 4D; an example size is 32x18x16x16, and I am inserting the alternate rows/cols only in the last two dimensions.
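
To make the 4D case concrete, here is a minimal sketch of the operation restricted to the last two axes (the helper name and the shape below are only illustrative):

import numpy as np

def insert_zeros_last2(a, N=1):
    # a : array of shape (..., H, W); zeros are inserted only along the last two axes
    *lead, H, W = a.shape
    out = np.zeros((*lead, (N + 1) * H - N, (N + 1) * W - N), dtype=a.dtype)
    out[..., ::N + 1, ::N + 1] = a
    return out

x = np.random.rand(32, 18, 16, 16)
y = insert_zeros_last2(x)  # shape (32, 18, 31, 31)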

So my question is, is there another implementation with the same functionality but with reduced time?

CodePudding user response:

I found a few methods to achieve this result, and the indexing method seems to be consistently the fastest.

There might be some improvement to be made on the other methods, though, because I tried to generalize them from 1D to 2D and an arbitrary number of leading dimensions, and I might not have done it in the best way possible.

Edit: added yet another method using numpy; it is not faster.

Performance test (CPU):

In [4]: N, C, H, W = 11, 5, 128, 128
   ...: x = torch.rand(N, C, H, W)
   ...: k = 3
   ...:
   ...: x1 = interleave_index(x, k)
   ...: x2 = interleave_view(x, k)
   ...: x3 = interleave_einops(x, k)
   ...: x4 = interleave_convtranspose(x, k)
   ...: x5 = interleave_numpy(x, k)
   ...:
   ...: assert torch.all(x1 == x2)
   ...: assert torch.all(x2 == x3)
   ...: assert torch.all(x3 == x4)
   ...: assert torch.all(x4 == x5)
   ...:
   ...: %timeit interleave_index(x, k)
   ...: %timeit interleave_view(x, k)
   ...: %timeit interleave_einops(x, k)
   ...: %timeit interleave_convtranspose(x, k)
   ...: %timeit interleave_numpy(x, k)

9.51 ms ± 2.21 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
12.6 ms ± 4.98 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
23.3 ms ± 4.19 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
62.5 ms ± 19.5 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
50.6 ms ± 809 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Performance test (GPU):

(numpy method not tested)

...: ...
...: x = torch.rand(N, C, H, W, device="cuda")
...: ...
260 µs ± 1.92 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
861 µs ± 6.77 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
912 µs ± 14.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
429 µs ± 5.08 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Implementations:

import torch
import torch.nn.functional as F
import einops
import numpy as np


def interleave_index(x, k):
    *cdims, Hin, Win = x.shape
    Hout = (k + 1) * (Hin - 1) + 1
    Wout = (k + 1) * (Win - 1) + 1
    out = x.new_zeros(*cdims, Hout, Wout)
    out[..., ::k + 1, ::k + 1] = x
    return out


def interleave_view(x, k):
    """
    From
    https://discuss.pytorch.org/t/how-to-interleave-two-tensors-along-certain-dimension/11332/4
    """
    *cdims, Hin, Win = x.shape
    Hout = (k + 1) * (Hin - 1) + 1
    Wout = (k + 1) * (Win - 1) + 1
    zeros = [torch.zeros_like(x)] * k
    out = torch.stack([x, *zeros], dim=-1).view(*cdims, Hin, Wout + k)[..., :-k]
    zeros = [torch.zeros_like(out)] * k
    out = torch.stack([out, *zeros], dim=-2).view(*cdims, Hout + k, Wout)[..., :-k, :]
    return out


def interleave_einops(x, k):
    """
    From
    https://discuss.pytorch.org/t/how-to-interleave-two-tensors-along-certain-dimension/11332/6
    """
    zeros = [torch.zeros_like(x)] * k
    out = einops.rearrange([x, *zeros], "t ... h w -> ... h (w t)")[..., :-k]
    zeros = [torch.zeros_like(out)] * k
    out = einops.rearrange([out, *zeros], "t ... h w -> ... (h t) w")[..., :-k, :]
    return out


def interleave_convtranspose(x, k):
    """
    From
    https://github.com/pytorch/pytorch/issues/7911#issuecomment-515493009
    """
    C = x.shape[-3]
    weight = x.new_ones(C, 1, 1, 1)
    return F.conv_transpose2d(x, weight=weight, stride=k + 1, groups=C)


def interleave_numpy(x, k):
    """
    From https://stackoverflow.com/a/53179919
    """
    pos = np.repeat(np.arange(1, x.shape[-1]), k)
    out = np.insert(x, pos, 0, axis=-1)
    pos = np.repeat(np.arange(1, x.shape[-2]), k)
    out = np.insert(out, pos, 0, axis=-2)
    return out

CodePudding user response:

Since you know the size of the array in advance, the first step to optimize is to create the out array outside the function. Then, try numba to JIT-compile the function and work in-place on the out array. This achieves a 5X speedup over the numpy version you posted.

import numpy as np
from numba import njit

@njit
def insert_zeros_n(a, out, N=1):
    # write a into every (N+1)-th row/col of the preallocated out array
    for i in range(a.shape[0]):
        for j in range(a.shape[1]):
            out[(N + 1) * i, (N + 1) * j] = a[i, j]

and call it with the specified N and a:

N = 1
a = np.arange(16*16).reshape(16, 16)
out = np.zeros((N + 1) * np.array(a.shape) - N, dtype=a.dtype)
insert_zeros_n(a, out, N)
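
Since the question mentions 4D arrays where only the last two dimensions are interleaved, a possible extension of the same kernel to that case might look like this (a sketch with illustrative names, not benchmarked):

@njit
def insert_zeros_4d_n(a, out, N=1):
    # a   : input of shape (B, C, H, W)
    # out : preallocated zero array of shape (B, C, (N+1)*H-N, (N+1)*W-N)
    for b in range(a.shape[0]):
        for c in range(a.shape[1]):
            for i in range(a.shape[2]):
                for j in range(a.shape[3]):
                    out[b, c, (N + 1) * i, (N + 1) * j] = a[b, c, i, j]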

CodePudding user response:

I am not familiar with PyTorch, but to accelerate the code that you provided, I think the JAX library will help a lot. For example:

import numpy as np
import jax
import jax.numpy as jnp
from functools import partial

a = np.arange(10000).reshape(100, 100)
b = jnp.array(a)

@partial(jax.jit, static_argnums=1)
def new(a, N):
    out = jnp.zeros((N + 1) * np.array(a.shape) - N, dtype=a.dtype)
    out = out.at[::N + 1, ::N + 1].set(a)
    return out

This will improve the runtime by about 10 times on GPU. It depends on the array size and N (the larger the sizes, the better the performance). You can see benchmarks in my Colab link, based on the 4 answers proposed so far (JAX beats the others).
I believe JAX can be one of the best libraries for your case if you can adapt it to your problem (it is possible).
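
For completeness, a usage sketch: JAX dispatches work asynchronously, so block_until_ready() is needed to measure the runtime fairly, and the first call includes compilation time.

out = new(b, 1).block_until_ready()   # first call compiles; subsequent calls are fast
%timeit new(b, 1).block_until_ready()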

CodePudding user response:

Encapsulated for any N, what about using numpy.kron with 4D inputs,

a = np.arange(1, 19).reshape((1, 2, 3, 3))
print(a)
# array([[[[ 1,  2,  3],
#          [ 4,  5,  6],
#          [ 7,  8,  9]],
# 
#         [[10, 11, 12],
#          [13, 14, 15],
#          [16, 17, 18]]]])


def interleave_kron(a, N=1):
    n = N   1
    return np.kron(
        a, np.hstack((1, np.zeros(pow(n, 2) - 1))).reshape((1, 1, n, n))
    )[..., :-N, :-N]

where np.hstack((1, np.zeros(pow(n, 2) - 1))).reshape((1, 1, n, n)) could be externalized/computed once and for all for the sake of performance (see the sketch after the example output below).

and then

>>> interleave_kron(a, N=2)
array([[[[ 1.,  0.,  0.,  2.,  0.,  0.,  3.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 4.,  0.,  0.,  5.,  0.,  0.,  6.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 7.,  0.,  0.,  8.,  0.,  0.,  9.]],

        [[10.,  0.,  0., 11.,  0.,  0., 12.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [13.,  0.,  0., 14.,  0.,  0., 15.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [16.,  0.,  0., 17.,  0.,  0., 18.]]]])

?
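
A possible way to do that externalization, as suggested above (a sketch; function and variable names are illustrative):

def make_kron_kernel(N=1):
    n = N + 1
    # kernel with a single 1 in the top-left corner of an (n x n) block
    return np.hstack((1, np.zeros(n * n - 1))).reshape((1, 1, n, n))

kernel = make_kron_kernel(N=2)  # build once, reuse across calls

def interleave_kron_fast(a, kernel, N=2):
    return np.kron(a, kernel)[..., :-N, :-N]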
