Vectorize jax.lax.cond with vmap

Time:12-10

Hi, why can't I vectorize the condition function so that it applies to a list of booleans? Or is there something else going on here?

import jax
import jax.numpy as jnp

DK = jnp.array([[True],[True],[False],[True]])
f1 = lambda x: 1
f2 = lambda y: 0
cond = lambda dk: jax.lax.cond(dk,f1,f2)
vcond = jax.vmap(cond)
vcond(DK)  # raises an error instead of returning an array

I was expecting it to give me an array.

Answer:

Try this:

import jax
import jax.numpy as jnp
DK = jnp.array([[True],[True],[False],[True]])
f1 = lambda x: 1
f2 = lambda y: 0
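# lax.cond(pred, true_fun, false_fun, *operands): dk is the scalar predicate
# and is also forwarded as the operand to whichever branch runs
cond = lambda dk: jax.lax.cond(dk, f1, f2, dk)
# vmap twice so that cond sees one scalar element of DK at a time
vcond = jax.vmap(jax.vmap(cond))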
vcond(DK)

Output:

DeviceArray([[1],
             [1],
             [0],
             [1]], dtype=int32, weak_type=True)

Answer:

There are two issues here. First, lax.cond requires a scalar predicate, but DK is a 2D array of shape (4, 1) and a single vmap only maps over its leading axis, so each call receives a 1D predicate of shape (1,). You can fix this by using two levels of vmap:

vcond = jax.vmap(jax.vmap(cond))
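To see where the 1D predicate comes from, the sketch below records the shape each call receives at one and at two levels of vmap. record_shape and seen_shapes are hypothetical helpers for illustration only; vmap traces the function once per call, so each vmap call appends exactly one shape.

```python
import jax
import jax.numpy as jnp

DK = jnp.array([[True], [True], [False], [True]])  # shape (4, 1)
seen_shapes = []  # hypothetical helper list, filled at trace time

def record_shape(dk):
    # Runs once while vmap traces; dk.shape is the per-example shape,
    # i.e. what lax.cond would receive as its predicate at this depth.
    seen_shapes.append(dk.shape)
    return dk

jax.vmap(record_shape)(DK)            # one vmap: maps over axis 0 only
jax.vmap(jax.vmap(record_shape))(DK)  # two vmaps: maps over both axes
print(seen_shapes)  # [(1,), ()]
```

With one vmap the predicate still has shape (1,), which lax.cond rejects; with two it is a scalar.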

Second, you've set up f1 and f2 to take an argument, but it doesn't appear you have any value x to pass to them (if you did, you could pass it as jax.lax.cond(dk, f1, f2, x)). If no argument is needed, you can redefine the functions to take no arguments. The resulting code looks like this:

import jax.numpy as jnp
import jax

DK = jnp.array([[True],[True],[False],[True]])
f1 = lambda: 1
f2 = lambda: 0
cond = lambda dk: jax.lax.cond(dk,f1,f2)
vcond = jax.vmap(jax.vmap(cond))
result = vcond(DK)
print(result)
# [[1]
#  [1]
#  [0]
#  [1]]
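If you do have an operand to pass, the same double vmap works with the jax.lax.cond(dk, f1, f2, x) form. A minimal sketch, assuming a hypothetical operand array X with the same shape as DK and illustrative branches x + 1 and x - 1 (neither appears in the question):

```python
import jax
import jax.numpy as jnp

DK = jnp.array([[True], [True], [False], [True]])
# Hypothetical operand with the same shape as DK, for illustration only.
X = jnp.array([[10.0], [20.0], [30.0], [40.0]])

def cond(dk, x):
    # lax.cond(pred, true_fun, false_fun, *operands): x is forwarded to the
    # selected branch; dk must be a scalar at this point.
    return jax.lax.cond(dk, lambda x: x + 1.0, lambda x: x - 1.0, x)

vcond = jax.vmap(jax.vmap(cond))  # maps over both axes of DK and X together
result = vcond(DK, X)
print(result)
# [[11.]
#  [21.]
#  [29.]
#  [41.]]
```

vmap maps over axis 0 of every positional argument by default, so DK and X are walked in lockstep.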

Note that you can avoid a lot of this complexity by replacing your code with jnp.where. For example:

result = jnp.where(DK, 1, 0)

This will lower to essentially the same XLA select operation as the double-vmapped lax.cond version, because vmap converts a cond with a batched predicate into a select.
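One way to check this yourself is to compare the results and print both jaxprs. This is a sketch; in recent JAX versions both jaxprs are expected to contain a select_n primitive rather than a cond, but primitive names are an implementation detail and may differ between versions. where_version is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp

DK = jnp.array([[True], [True], [False], [True]])
f1 = lambda: 1
f2 = lambda: 0
vcond = jax.vmap(jax.vmap(lambda dk: jax.lax.cond(dk, f1, f2)))

def where_version(dk):
    # Elementwise choice between 1 and 0, no vmap needed.
    return jnp.where(dk, 1, 0)

# Both approaches produce the same values...
assert (vcond(DK) == where_version(DK)).all()

# ...and the jaxprs show which primitives each one traces to.
print(jax.make_jaxpr(vcond)(DK))
print(jax.make_jaxpr(where_version)(DK))
```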
