I have a problem: I would like to repeat every row of an array of shape (X, Y) n times, without using loops, to get an array of shape (n*X, Y).
import jax.numpy as jnp
arr = jnp.array([[12, 14, 12, 0, 1],
[0, 14, 12, 0, 1],
[0, 0, 12, 0, 1]])
n = 3
# desired result, shape (n*X, Y) = (9, 5)
result = jnp.array([[12, 14, 12, 0, 1],
[12, 14, 12, 0, 1],
[12, 14, 12, 0, 1],
[0, 14, 12, 0, 1],
[0, 14, 12, 0, 1],
[0, 14, 12, 0, 1],
[0, 0, 12, 0, 1],
[0, 0, 12, 0, 1],
[0, 0, 12, 0, 1]])
I haven't found a built-in method for this operation. I tried jnp.tile and jnp.repeat.
With jnp.repeat:
arr_r = jnp.repeat(arr, n, axis=1)
Output:
[[12 12 12 14 14 14 12 12 12 0 0 0 1 1 1]
[ 0 0 0 14 14 14 12 12 12 0 0 0 1 1 1]
[ 0 0 0 0 0 0 12 12 12 0 0 0 1 1 1]]
With jnp.tile:
arr_t = jnp.tile(arr, n)
Output:
[[12 14 12 0 1 12 14 12 0 1 12 14 12 0 1]
[ 0 14 12 0 1 0 14 12 0 1 0 14 12 0 1]
[ 0 0 12 0 1 0 0 12 0 1 0 0 12 0 1]]
Maybe I could construct the result array from arr_t...
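The hunch about arr_t does work. Each row of arr_t already holds n back-to-back copies of one original row, so a row-major reshape to (n*X, Y) splits it into n identical rows in exactly the desired order. A minimal sketch, reusing the arr and n from the question:

```python
import jax.numpy as jnp

arr = jnp.array([[12, 14, 12, 0, 1],
                 [0, 14, 12, 0, 1],
                 [0, 0, 12, 0, 1]])
n = 3

# jnp.tile(arr, n) tiles along the last axis, giving shape (X, n*Y) = (3, 15).
arr_t = jnp.tile(arr, n)

# Row-major reshape: each 15-element row becomes n rows of Y = 5 elements,
# all equal to the original row, so rows come out as r0, r0, r0, r1, r1, r1, ...
result = arr_t.reshape(-1, arr.shape[1])
print(result.shape)  # (9, 5)
```

Note that jnp.tile(arr, (n, 1)) would instead stack the whole array n times (r0, r1, r2, r0, r1, r2, ...), which is not the order asked for here.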
Answer:
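You say you tried jnp.repeat, but you don't explain why it doesn't do what you want. I'm guessing the problem is the axis parameter: you passed axis=1, which repeats each column, whereas repeating rows means repeating along axis 0: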
jnp.repeat(arr, n, axis=0)
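As a quick check, a sketch that compares the axis=0 version against the desired result from the question:

```python
import jax.numpy as jnp

arr = jnp.array([[12, 14, 12, 0, 1],
                 [0, 14, 12, 0, 1],
                 [0, 0, 12, 0, 1]])
n = 3

# axis=0 repeats along the row axis: each row is emitted n times consecutively,
# giving shape (n*X, Y) = (9, 5).
result = jnp.repeat(arr, n, axis=0)

expected = jnp.array([[12, 14, 12, 0, 1],
                      [12, 14, 12, 0, 1],
                      [12, 14, 12, 0, 1],
                      [0, 14, 12, 0, 1],
                      [0, 14, 12, 0, 1],
                      [0, 14, 12, 0, 1],
                      [0, 0, 12, 0, 1],
                      [0, 0, 12, 0, 1],
                      [0, 0, 12, 0, 1]])
print(bool(jnp.array_equal(result, expected)))  # True
```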