Given batch of samples find x,y positions - python, jax

Time:01-17

I would like to find the positions where there is a '1' in the following batched array with dimensions [batch, 4, 4] = [2, 4, 4].

import jax
import jax.numpy as jnp

a = jnp.array([[[0., 0., 0., 1.],
                [0., 0., 0., 0.],
                [0., 1., 0., 1.],
                [0., 0., 1., 1.]],
             
               [[1., 0., 1., 0.],
                [1., 0., 0., 0.],
                [0., 0., 0., 0.],
                [0., 1., 0., 1.]]])

I tried mapping over the batch dimension with vmap and using jnp.where to find the coordinates:

b = jax.vmap(jnp.where)(a)
print('b', b)

but I get an error that I don't know how to fix:

The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
This Tracer was created on line /home/imi/Desktop/Backflow/backflow/src/debug.py:17 (<module>)

I expect the following output:

b = [[[0,3], [2,1], [2,3], [3,2], [3,3]],
     [[0,0], [0,2], [1,0], [3,1], [3,3]]]

The first line of [x,y] coordinates (that is, [row, column] indices) corresponds to the positions of the '1' entries in the first batch element, and the second line to those in the second batch element.

CodePudding user response:

JAX transformations like vmap require arrays to be statically-sized, so there is no way to do exactly the computation you have in mind (because the number of 1 entries, and thus the size of the output array, is data-dependent).
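To make the data-dependence concrete, here is a minimal sketch with two small hypothetical arrays (x and y are made up for illustration, not taken from the question). Called without a size argument and outside any transformation, jnp.where returns a different number of coordinates for each input, which is exactly what vmap cannot represent in a single fixed-shape output:

```python
import jax.numpy as jnp

# Hypothetical 2x2 inputs with different numbers of nonzero entries.
x = jnp.array([[0., 1.],
               [1., 1.]])
y = jnp.array([[0., 0.],
               [1., 0.]])

# Outside of JAX transformations, the output shape depends on the data.
coords_x = jnp.column_stack(jnp.where(x))
coords_y = jnp.column_stack(jnp.where(y))
print(coords_x.shape)  # (3, 2)
print(coords_y.shape)  # (1, 2)
```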

But if you know a priori that there are five entries per batch, you can do something like this:

from functools import partial
indices = jax.vmap(partial(jnp.where, size=5))(a)
print(jnp.stack(indices, axis=2))
[[[0 3]
  [2 1]
  [2 3]
  [3 2]
  [3 3]]

 [[0 0]
  [0 2]
  [1 0]
  [3 1]
  [3 3]]]
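If you want to check the "five entries per batch" assumption rather than hard-code it, one possible sketch is to count the nonzero entries eagerly, outside any transformation, and convert the result to a Python int. This only works when a is a concrete array (not a tracer inside jit or vmap); the names counts and size are just illustrative:

```python
from functools import partial

import jax
import jax.numpy as jnp

a = jnp.array([[[0., 0., 0., 1.],
                [0., 0., 0., 0.],
                [0., 1., 0., 1.],
                [0., 0., 1., 1.]],
               [[1., 0., 1., 0.],
                [1., 0., 0., 0.],
                [0., 0., 0., 0.],
                [0., 1., 0., 1.]]])

# Count the nonzero entries in each batch element (eager, concrete values).
counts = jnp.count_nonzero(a, axis=(1, 2))
print(counts)  # [5 5]

# A Python int is a static value, so it can be passed as `size`.
size = int(counts.max())
indices = jax.vmap(partial(jnp.where, size=size))(a)
coords = jnp.stack(indices, axis=2)  # shape (2, 5, 2)
```

If the counts differ between batch elements, using the maximum as size means the shorter results are padded (with zeros by default), so the padded case discussed further down applies.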

If you don't know a priori how many 1 entries there are, then you have a few options: one is to avoid JAX transformations and call an un-transformed jnp.where on each batch:

result = [jnp.column_stack(jnp.where(b)) for b in a]
print(result)
[DeviceArray([[0, 3],
             [2, 1],
             [2, 3],
             [3, 2],
             [3, 3]], dtype=int32), DeviceArray([[0, 0],
             [0, 2],
             [1, 0],
             [3, 1],
             [3, 3]], dtype=int32)]

Note that for this case, it's not possible in general to store the results in a single array, because there may be different numbers of 1 entries in each batch, and JAX does not support ragged arrays.

The other option is to set the size to some maximum value, and output padded results:

max_size = a[0].size  # size of slice is the upper bound
fill_value = a[0].shape  # fill with out-of-bound indices
indices = jax.vmap(partial(jnp.where, size=max_size, fill_value=fill_value))(a)
print(jnp.stack(indices, axis=2))
[[[0 3]
  [2 1]
  [2 3]
  [3 2]
  [3 3]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]]

 [[0 0]
  [0 2]
  [1 0]
  [3 1]
  [3 3]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]]]

With padded results, you could then write the remainder of your code to anticipate these padded values.
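As a rough illustration of what "anticipating the padded values" might look like, the sketch below rebuilds the padded result and derives a boolean mask from the out-of-bounds fill value. The names padded, valid, counts and row_sum are illustrative, and the row-index sum is only a stand-in for whatever downstream computation you actually need:

```python
from functools import partial

import jax
import jax.numpy as jnp

a = jnp.array([[[0., 0., 0., 1.],
                [0., 0., 0., 0.],
                [0., 1., 0., 1.],
                [0., 0., 1., 1.]],
               [[1., 0., 1., 0.],
                [1., 0., 0., 0.],
                [0., 0., 0., 0.],
                [0., 1., 0., 1.]]])

max_size = a[0].size       # upper bound on the number of hits per slice
fill_value = a[0].shape    # (4, 4): out-of-bounds indices mark padding
indices = jax.vmap(partial(jnp.where, size=max_size, fill_value=fill_value))(a)
padded = jnp.stack(indices, axis=2)  # shape (2, 16, 2)

# A coordinate is real if its row index is in bounds; padding rows hold 4.
valid = padded[..., 0] < a.shape[1]  # boolean mask, shape (2, 16)
counts = valid.sum(axis=1)           # number of real coordinates per batch

# Example masked computation: sum of the row indices of the real entries only.
row_sum = jnp.where(valid, padded[..., 0], 0).sum(axis=1)
```

Because the mask has a fixed shape, this style of masked computation stays compatible with vmap and jit, unlike slicing the padding away.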
