I would like to find the positions where there is a '1' in the following batch of sampled arrays, with shape [batch, 4, 4] = [2, 4, 4]:
import jax
import jax.numpy as jnp
a = jnp.array([[[0., 0., 0., 1.],
                [0., 0., 0., 0.],
                [0., 1., 0., 1.],
                [0., 0., 1., 1.]],
               [[1., 0., 1., 0.],
                [1., 0., 0., 0.],
                [0., 0., 0., 0.],
                [0., 1., 0., 1.]]])
I tried mapping over the batch dimension with vmap and using jnp.where to find the coordinates:
b = jax.vmap(jnp.where)(a)
print('b', b)
but I get an error that I don't know how to fix:
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
This Tracer was created on line /home/imi/Desktop/Backflow/backflow/src/debug.py:17 (<module>)
I expect the following output:
b = [[[0, 3], [2, 1], [2, 3], [3, 2], [3, 3]],
     [[0, 0], [0, 2], [1, 0], [3, 1], [3, 3]]]
The first line of [row, column] coordinates gives the positions of the 1 entries in the first batch element, and the second line gives those in the second batch element.
Answer:
JAX transformations like vmap require arrays to be statically sized, so there is no way to do exactly the computation you have in mind: the number of 1 entries, and thus the size of the output array, is data-dependent.
But if you know a priori that there are five 1 entries in each batch element, you can do something like this:
from functools import partial
indices = jax.vmap(partial(jnp.where, size=5))(a)
print(jnp.stack(indices, axis=2))
[[[0 3]
[2 1]
[2 3]
[3 2]
[3 3]]
[[0 0]
[0 2]
[1 0]
[3 1]
[3 3]]]
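The size does not have to be hard-coded. As a sketch, assuming `a` is a concrete array (not a tracer inside `jit` or `vmap`), the largest per-batch count can be computed first and passed in as the static `size`. The same snippet also shows what happens when `size` is too small: `jnp.where`/`jnp.nonzero` truncate the result rather than raising an error. The names `counts`, `max_count`, `coords` and `truncated` are illustrative only.

```python
from functools import partial

import jax
import jax.numpy as jnp

a = jnp.array([[[0., 0., 0., 1.],
                [0., 0., 0., 0.],
                [0., 1., 0., 1.],
                [0., 0., 1., 1.]],
               [[1., 0., 1., 0.],
                [1., 0., 0., 0.],
                [0., 0., 0., 0.],
                [0., 1., 0., 1.]]])

# Outside any transformation `a` is concrete, so int() yields a plain Python
# int that can be used as the static `size` argument.
# (Calling int() on a traced value inside jit would fail.)
counts = (a != 0).sum(axis=(1, 2))   # number of nonzero entries per batch element
max_count = int(counts.max())        # 5 for this `a`

rows, cols = jax.vmap(partial(jnp.where, size=max_count))(a)
coords = jnp.stack([rows, cols], axis=2)  # shape (batch, max_count, 2)

# If `size` is smaller than the true count, the indices are silently truncated:
# only the first 3 coordinates of each batch element are kept here.
rows3, cols3 = jax.vmap(partial(jnp.where, size=3))(a)
truncated = jnp.stack([rows3, cols3], axis=2)  # shape (2, 3, 2)
```

If the batch elements have different counts, the shorter ones are padded with `fill_value`, which defaults to 0 and is therefore indistinguishable from a real index 0; passing an out-of-bounds `fill_value` avoids that ambiguity.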
If you don't know a priori how many 1 entries there are, you have two options. One is to avoid JAX transformations and call an untransformed jnp.where on each batch element:
result = [jnp.column_stack(jnp.where(b)) for b in a]
print(result)
[DeviceArray([[0, 3],
[2, 1],
[2, 3],
[3, 2],
[3, 3]], dtype=int32), DeviceArray([[0, 0],
[0, 2],
[1, 0],
[3, 1],
[3, 3]], dtype=int32)]
Note that in this case it's not possible in general to store the results in a single array, because each batch element may contain a different number of 1 entries, and JAX does not support ragged arrays.
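If a single array is still wanted in the untransformed setting, one common workaround (not part of the original answer) is to keep the batch index as an extra column, which is what `jnp.argwhere` produces when called on the whole batch. The per-batch pieces can then be recovered on the host with NumPy. A sketch, assuming `a` is concrete; `coo` and `per_batch` are illustrative names:

```python
import numpy as np
import jax.numpy as jnp

a = jnp.array([[[0., 0., 0., 1.],
                [0., 0., 0., 0.],
                [0., 1., 0., 1.],
                [0., 0., 1., 1.]],
               [[1., 0., 1., 0.],
                [1., 0., 0., 0.],
                [0., 0., 0., 0.],
                [0., 1., 0., 1.]]])

# One row per nonzero entry: [batch_index, row, col], in row-major order,
# so all rows of batch element 0 come before those of batch element 1.
coo = jnp.argwhere(a)  # shape (total number of 1 entries, 3)

# Split into a ragged Python list with one (n_i, 2) array per batch element.
coo_np = np.asarray(coo)
counts = np.bincount(coo_np[:, 0], minlength=a.shape[0])
per_batch = np.split(coo_np[:, 1:], np.cumsum(counts)[:-1])
```

Like the loop above, this relies on a data-dependent output shape, so `jnp.argwhere` without a `size` argument also cannot be used inside `jit` or `vmap`.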
The other option is to set size to some maximum value and output padded results:
max_size = a[0].size # size of slice is the upper bound
fill_value = a[0].shape # fill with out-of-bound indices
indices = jax.vmap(partial(jnp.where, size=max_size, fill_value=fill_value))(a)
print(jnp.stack(indices, axis=2))
[[[0 3]
[2 1]
[2 3]
[3 2]
[3 3]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]]
[[0 0]
[0 2]
[1 0]
[3 1]
[3 3]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]]]
With padded results, you could then write the remainder of your code to anticipate these padded values.
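As a sketch of what "anticipating the padded values" can look like, assuming the same out-of-bounds `fill_value` as above: a boolean mask marks the real entries, downstream computations use it to ignore padding, and the results can be trimmed back to a ragged list on the host if needed. The function `padded_coords` and the mean-row computation are hypothetical examples, not part of the original answer.

```python
from functools import partial

import jax
import jax.numpy as jnp

a = jnp.array([[[0., 0., 0., 1.],
                [0., 0., 0., 0.],
                [0., 1., 0., 1.],
                [0., 0., 1., 1.]],
               [[1., 0., 1., 0.],
                [1., 0., 0., 0.],
                [0., 0., 0., 0.],
                [0., 1., 0., 1.]]])

max_size = a[0].size     # 16: upper bound on the number of 1 entries
fill_value = a[0].shape  # (4, 4): out-of-bounds indices used as padding

@jax.jit
def padded_coords(x):
    rows, cols = jax.vmap(partial(jnp.where, size=max_size, fill_value=fill_value))(x)
    coords = jnp.stack([rows, cols], axis=2)  # (batch, max_size, 2)
    valid = rows < x.shape[1]                 # False exactly on the padded rows
    return coords, valid

coords, valid = padded_coords(a)
counts = valid.sum(axis=1)  # number of real entries per batch element

# Example masked computation: mean row index of the 1 entries per batch element.
mean_row = jnp.where(valid, coords[..., 0], 0).sum(axis=1) / counts

# Optional: trim back to a ragged list of (n_i, 2) arrays outside of jit.
ragged = [c[:n] for c, n in zip(coords, counts.tolist())]
```

Keeping the mask alongside the padded array lets everything up to the final trim stay inside `jit` with fixed shapes.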