Repeating rows from array

Time:04-23

I have a problem: I would like to repeat each row of an array of shape (X, Y) n times, without using loops, to get an array of shape (n*X, Y). For example:

import jax.numpy as jnp

arr = jnp.array([[12, 14, 12, 0, 1],
                [0, 14, 12, 0, 1],
                [0, 0, 12, 0, 1]])
n = 3

result = jnp.array([[12, 14, 12, 0, 1],
                    [12, 14, 12, 0, 1],
                    [12, 14, 12, 0, 1],
                    [0, 14, 12, 0, 1],
                    [0, 14, 12, 0, 1],
                    [0, 14, 12, 0, 1],
                    [0, 0, 12, 0, 1],
                    [0, 0, 12, 0, 1],
                    [0, 0, 12, 0, 1]])

I haven't found a built-in method that performs this operation. I tried jnp.tile and jnp.repeat:

jnp.repeat

arr_r = jnp.repeat(arr, n, axis=1)

Output:
[[12 12 12 14 14 14 12 12 12  0  0  0  1  1  1]
 [ 0  0  0 14 14 14 12 12 12  0  0  0  1  1  1]
 [ 0  0  0  0  0  0 12 12 12  0  0  0  1  1  1]]

jnp.tile

arr_t = jnp.tile(arr, n)

Output:
[[12 14 12  0  1 12 14 12  0  1 12 14 12  0  1]
 [ 0 14 12  0  1  0 14 12  0  1  0 14 12  0  1]
 [ 0  0 12  0  1  0  0 12  0  1  0  0 12  0  1]]
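For reference, jnp.tile follows NumPy's rule for the reps argument: when reps has fewer entries than the array has dimensions, it is padded with leading 1s, so jnp.tile(arr, n) on a 2-D array means reps=(1, n) and the copies go side by side along the last axis. Tiling along axis 0 instead, with reps=(n, 1), stacks copies of the whole array, which gives the rows in the order 0, 1, 2, 0, 1, 2, ... rather than each row repeated in place. A minimal sketch, reusing arr and n from the question:

```python
import jax.numpy as jnp

arr = jnp.array([[12, 14, 12, 0, 1],
                 [0, 14, 12, 0, 1],
                 [0, 0, 12, 0, 1]])
n = 3

# An integer reps on a 2-D array is treated as (1, reps):
# copies are placed next to each other along the last axis.
assert jnp.array_equal(jnp.tile(arr, n), jnp.tile(arr, (1, n)))

# reps=(n, 1) stacks n copies of the whole array vertically,
# so the row order is 0, 1, 2, 0, 1, 2, 0, 1, 2 (not the desired layout).
stacked = jnp.tile(arr, (n, 1))
print(stacked.shape)  # (9, 5)
```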

Maybe I could construct the result array from arr_t...
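That idea does work: each row of arr_t holds n back-to-back copies of the corresponding original row, so a row-major reshape to (X*n, Y) splits it into n consecutive rows that are all identical. A minimal sketch, assuming arr and n as defined above:

```python
import jax.numpy as jnp

arr = jnp.array([[12, 14, 12, 0, 1],
                 [0, 14, 12, 0, 1],
                 [0, 0, 12, 0, 1]])
n = 3

X, Y = arr.shape
arr_t = jnp.tile(arr, n)  # shape (X, n*Y): each row is n copies of itself

# Row-major (C-order) reshape cuts every length-n*Y row into n rows of
# length Y, so each original row ends up repeated n times in place.
result = arr_t.reshape(X * n, Y)
print(result.shape)  # (9, 5)
```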

Answer:

Your jnp.repeat call uses axis=1, which repeats each element along the columns. To repeat each row n times, repeat along the row axis (axis=0) instead:

jnp.repeat(arr, n, axis=0)
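A quick check, assuming arr and n from the question: jnp.repeat with axis=0 reproduces the expected result exactly. The second half of the sketch shows an equivalent broadcasting formulation (insert a new axis of length n after the row axis, then merge the first two axes); it is an alternative, not part of the answer above, and jnp.repeat is the more direct choice.

```python
import jax.numpy as jnp

arr = jnp.array([[12, 14, 12, 0, 1],
                 [0, 14, 12, 0, 1],
                 [0, 0, 12, 0, 1]])
n = 3
X, Y = arr.shape

# Reference result built in plain Python, only for comparison.
expected = jnp.array([row for row in arr.tolist() for _ in range(n)])

# axis=0: every row is copied n times consecutively -> shape (X*n, Y).
result = jnp.repeat(arr, n, axis=0)
assert jnp.array_equal(result, expected)

# Alternative: broadcast to (X, n, Y), then flatten the first two axes.
result_bc = jnp.broadcast_to(arr[:, None, :], (X, n, Y)).reshape(X * n, Y)
assert jnp.array_equal(result_bc, expected)
```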