I need to interleave PyTorch tensors (similar to NumPy arrays) with rows and columns of zeros, like this:
Input  => [[1, 2, 3],
           [4, 5, 6],
           [7, 8, 9]]

Output => [[1, 0, 2, 0, 3],
           [0, 0, 0, 0, 0],
           [4, 0, 5, 0, 6],
           [0, 0, 0, 0, 0],
           [7, 0, 8, 0, 9]]
I am using the accepted answer to this question, which proposes the following:
import numpy as np

def insert_zeros(a, N=1):
    # a : input array
    # N : number of zeros to be inserted between consecutive rows and cols
    out = np.zeros((N + 1) * np.array(a.shape) - N, dtype=a.dtype)
    out[::N + 1, ::N + 1] = a
    return out
The answer works perfectly, except that I need to perform this many times on many arrays and the time it takes has become the bottleneck. It is the step-sized slicing that takes most of the time.
For what it's worth, the matrices I am using this for are 4D, an example size being 32x18x16x16, and I am inserting the alternating rows/cols only in the last two dimensions.
So my question is: is there another implementation with the same functionality but with a shorter runtime?
CodePudding user response:
I found a few methods to achieve this result, and the indexing method seems to be consistently the fastest.
There might be some improvement to be made on the other methods though, because I tried to generalize them from 1D to 2D and an arbitrary number of leading dimensions, and might not have done it in the best way possible.
Edit: added yet another method using numpy; it is not faster.
Performance test (CPU):
In [4]: N, C, H, W = 11, 5, 128, 128
...: x = torch.rand(N, C, H, W)
...: k = 3
...:
...: x1 = interleave_index(x, k)
...: x2 = interleave_view(x, k)
...: x3 = interleave_einops(x, k)
...: x4 = interleave_convtranspose(x, k)
...: x5 = interleave_numpy(x, k)
...:
...: assert torch.all(x1 == x2)
...: assert torch.all(x2 == x3)
...: assert torch.all(x3 == x4)
...: assert torch.all(x4 == x5)
...:
...: %timeit interleave_index(x, k)
...: %timeit interleave_view(x, k)
...: %timeit interleave_einops(x, k)
...: %timeit interleave_convtranspose(x, k)
...: %timeit interleave_numpy(x, k)
9.51 ms ± 2.21 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
12.6 ms ± 4.98 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
23.3 ms ± 4.19 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
62.5 ms ± 19.5 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
50.6 ms ± 809 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Performance test (GPU):
(numpy method not tested)
...: ...
...: x = torch.rand(N, C, H, W, device="cuda")
...: ...
260 µs ± 1.92 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
861 µs ± 6.77 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
912 µs ± 14.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
429 µs ± 5.08 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Implementations:
import numpy as np
import torch
import torch.nn.functional as F
import einops
def interleave_index(x, k):
    *cdims, Hin, Win = x.shape
    Hout = (k + 1) * (Hin - 1) + 1
    Wout = (k + 1) * (Win - 1) + 1
    out = x.new_zeros(*cdims, Hout, Wout)
    out[..., ::k + 1, ::k + 1] = x
    return out
def interleave_view(x, k):
    """
    From
    https://discuss.pytorch.org/t/how-to-interleave-two-tensors-along-certain-dimension/11332/4
    """
    *cdims, Hin, Win = x.shape
    Hout = (k + 1) * (Hin - 1) + 1
    Wout = (k + 1) * (Win - 1) + 1
    zeros = [torch.zeros_like(x)] * k
    out = torch.stack([x, *zeros], dim=-1).view(*cdims, Hin, Wout + k)[..., :-k]
    zeros = [torch.zeros_like(out)] * k
    out = torch.stack([out, *zeros], dim=-2).view(*cdims, Hout + k, Wout)[..., :-k, :]
    return out
def interleave_einops(x, k):
    """
    From
    https://discuss.pytorch.org/t/how-to-interleave-two-tensors-along-certain-dimension/11332/6
    """
    zeros = [torch.zeros_like(x)] * k
    out = einops.rearrange([x, *zeros], "t ... h w -> ... h (w t)")[..., :-k]
    zeros = [torch.zeros_like(out)] * k
    out = einops.rearrange([out, *zeros], "t ... h w -> ... (h t) w")[..., :-k, :]
    return out
def interleave_convtranspose(x, k):
    """
    From
    https://github.com/pytorch/pytorch/issues/7911#issuecomment-515493009
    """
    C = x.shape[-3]
    weight = x.new_ones(C, 1, 1, 1)
    return F.conv_transpose2d(x, weight=weight, stride=k + 1, groups=C)
def interleave_numpy(x, k):
    """
    From https://stackoverflow.com/a/53179919
    """
    pos = np.repeat(np.arange(1, x.shape[-1]), k)
    out = np.insert(x, pos, 0, axis=-1)
    pos = np.repeat(np.arange(1, x.shape[-2]), k)
    out = np.insert(out, pos, 0, axis=-2)
    return out
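For reference, here is a minimal sanity check on the 4D use case from the question (the 32x18x16x16 shape and k=1 below are illustrative assumptions taken from the question, not part of the benchmark above):

x = torch.rand(32, 18, 16, 16)
y = interleave_index(x, k=1)
print(y.shape)                           # torch.Size([32, 18, 31, 31])
assert torch.all(y[..., ::2, ::2] == x)  # original values sit on the even grid, zeros elsewhere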
CodePudding user response:
Since you know the size of the array in advance, the first step to optimize is to create the out array outside the function. Then, try numba to jit-compile the function and work in-place on the out array. This achieves a 5x speedup over the numpy version you posted.
import numpy as np
from numba import njit

@njit
def insert_zeros_n(a, out, N=1):
    for i in range(a.shape[0]):
        for j in range(a.shape[1]):
            out[(N + 1) * i, (N + 1) * j] = a[i, j]
and call it with the specified N and a:
N = 1
a = np.arange(16 * 16).reshape(16, 16)
out = np.zeros((N + 1) * np.array(a.shape) - N, dtype=a.dtype)
insert_zeros_n(a, out, N)
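If all your arrays share the same shape, the preallocated out buffer can be reused across calls. A minimal sketch of that (the batch of arrays and the final copy are illustrative assumptions):

arrays = [np.arange(16 * 16).reshape(16, 16) for _ in range(1000)]
out = np.zeros((N + 1) * np.array(arrays[0].shape) - N, dtype=arrays[0].dtype)
results = []
for a in arrays:
    insert_zeros_n(a, out, N)   # writes in-place; the zero positions never change
    results.append(out.copy())  # copy only if every result must be kept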
CodePudding user response:
I am not familiar with PyTorch, but to accelerate the code that you provided, I think the JAX library will help a lot. So, the following:
import numpy as np
import jax
import jax.numpy as jnp
from functools import partial

a = np.arange(10000).reshape(100, 100)
b = jnp.array(a)

@partial(jax.jit, static_argnums=1)
def new(a, N):
    out = jnp.zeros((N + 1) * np.array(a.shape) - N, dtype=a.dtype)
    out = out.at[::N + 1, ::N + 1].set(a)
    return out
will improve the runtime by about 10 times on GPU. It depends on the array size and N (the larger the sizes, the better the performance). You can see benchmarks in my Colab link, based on the 4 answers proposed so far (JAX beats the others).
I believe JAX can be one of the best libraries for your case if you can adapt it to your problem (it is possible).
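A minimal usage sketch (the call pattern below is an assumption about how one would time this, not part of the answer above): the first call with a given static N triggers JIT compilation, and JAX dispatches asynchronously, so block on the result when benchmarking.

out = new(b, 1)                       # warm-up call: compiles for this N
out = new(b, 1).block_until_ready()   # later calls reuse the compiled kernel
print(out.shape)                      # (199, 199) for the 100x100 input with N=1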
CodePudding user response:
Encapsulated for any N, what about using numpy.kron with 4D inputs,
a = np.arange(1, 19).reshape((1, 2, 3, 3))
print(a)
# array([[[[ 1, 2, 3],
# [ 4, 5, 6],
# [ 7, 8, 9]],
#
# [[10, 11, 12],
# [13, 14, 15],
# [16, 17, 18]]]])
def interleave_kron(a, N=1):
    n = N + 1
    return np.kron(
        a, np.hstack((1, np.zeros(pow(n, 2) - 1))).reshape((1, 1, n, n))
    )[..., :-N, :-N]
where np.hstack((1, np.zeros(pow(n, 2) - 1))).reshape((1, 1, n, n)) could be externalized/defaulted once and for all for the sake of performance (a sketch of that is given after the example below),
and then
>>> interleave_kron(a, N=2)
array([[[[ 1., 0., 0., 2., 0., 0., 3.],
[ 0., 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0., 0., 0.],
[ 4., 0., 0., 5., 0., 0., 6.],
[ 0., 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0., 0., 0.],
[ 7., 0., 0., 8., 0., 0., 9.]],
[[10., 0., 0., 11., 0., 0., 12.],
[ 0., 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0., 0., 0.],
[13., 0., 0., 14., 0., 0., 15.],
[ 0., 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0., 0., 0.],
[16., 0., 0., 17., 0., 0., 18.]]]])
?
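As a rough sketch of the externalization suggested above (the helper name make_kernel and the precomputed-kernel variant are illustrative assumptions, not part of the original answer):

def make_kernel(N=1):
    # single 1 followed by zeros, shaped so np.kron broadcasts over the 4D input
    n = N + 1
    return np.hstack((1, np.zeros(n * n - 1))).reshape((1, 1, n, n))

kernel = make_kernel(N=2)  # build once, reuse for every call with the same N

def interleave_kron_precomputed(a, kernel, N=1):
    return np.kron(a, kernel)[..., :-N, :-N]

# interleave_kron_precomputed(a, kernel, N=2) matches interleave_kron(a, N=2)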