Hi, why can't I vectorize this condition function to apply it to a list of booleans? Or is there something else going on here?
import jax
import jax.numpy as jnp
DK = jnp.array([[True], [True], [False], [True]])
f1 = lambda x: 1
f2 = lambda y: 0
cond = lambda dk: jax.lax.cond(dk, f1, f2)
vcond = jax.vmap(cond)
vcond(DK)
I was expecting it to give me an array.
Answer:
Try this:
import jax
import jax.numpy as jnp
DK = jnp.array([[True],[True],[False],[True]])
f1 = lambda x: 1
f2 = lambda y: 0
# Older lax.cond calling convention: cond(pred, true_operand, true_fun, false_operand, false_fun)
cond = lambda dk: jax.lax.cond(dk,
                               dk, lambda x: f1(x),
                               dk, lambda x: f2(x))
vcond = jax.vmap(jax.vmap(cond))
vcond(DK)
Output:
DeviceArray([[1],
             [1],
             [0],
             [1]], dtype=int32, weak_type=True)
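The answer above uses lax.cond's older calling convention, in which each branch gets its own operand. Assuming a JAX version with the current cond(pred, true_fun, false_fun, *operands) signature, a sketch of the same idea passes dk once, after the two branch functions:

```python
import jax
import jax.numpy as jnp

DK = jnp.array([[True], [True], [False], [True]])
f1 = lambda x: 1
f2 = lambda y: 0

# Trailing operands (here dk) are passed to the branch functions.
cond = lambda dk: jax.lax.cond(dk, f1, f2, dk)
vcond = jax.vmap(jax.vmap(cond))  # two levels so each predicate is a scalar
result = vcond(DK)
print(result.tolist())  # [[1], [1], [0], [1]]
```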
Answer:
There are two issues here. First, lax.cond requires a scalar predicate, but you are vmapping over a 2D input: DK has shape (4, 1), and vmap only maps over the leading axis, so each call to cond effectively receives a 1D predicate of shape (1,). You can fix this by using two levels of vmap:
vcond = jax.vmap(jax.vmap(cond))
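To see why one vmap is not enough, here is a small sketch that records the shape of the argument each mapped function receives (the helpers record_outer and record_inner are made up for illustration). Each level of vmap strips one leading axis, so one level leaves a (1,)-shaped predicate and two levels leave a scalar:

```python
import jax
import jax.numpy as jnp

DK = jnp.array([[True], [True], [False], [True]])  # shape (4, 1)

outer_shapes = []
inner_shapes = []

def record_outer(dk):
    # vmap traces this once; dk.shape excludes the mapped axis.
    outer_shapes.append(dk.shape)
    return dk

def record_inner(dk):
    inner_shapes.append(dk.shape)
    return dk

jax.vmap(record_outer)(DK)            # maps over axis 0 only
jax.vmap(jax.vmap(record_inner))(DK)  # maps over both axes

print(outer_shapes)  # [(1,)]
print(inner_shapes)  # [()]
```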
Second, you've set up your f1 and f2 to each take an argument, but it doesn't appear you have any value x to pass to them. lax.cond calls the branch functions with whatever operands you pass after them, so with no operands it calls them with no arguments (if you did have an x, you could pass it as jax.lax.cond(dk, f1, f2, x)). If no x argument is needed, you can redefine the functions to take no arguments. The resulting code looks like this:
import jax.numpy as jnp
import jax
DK = jnp.array([[True],[True],[False],[True]])
f1 = lambda: 1
f2 = lambda: 0
cond = lambda dk: jax.lax.cond(dk,f1,f2)
vcond = jax.vmap(jax.vmap(cond))
result = vcond(DK)
print(result)
# [[1]
# [1]
# [0]
# [1]]
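If you do have a value to pass, here is a sketch of the jax.lax.cond(dk, f1, f2, x) form under the same two-level vmap. X is a hypothetical per-element operand with the same shape as DK, and the branch bodies are made up for illustration:

```python
import jax
import jax.numpy as jnp

DK = jnp.array([[True], [True], [False], [True]])
# Hypothetical operand, one value per predicate (not from the question).
X = jnp.array([[10.0], [20.0], [30.0], [40.0]])

f1 = lambda x: x + 1.0  # result where dk is True
f2 = lambda x: x - 1.0  # result where dk is False

# x is passed to the branches as a trailing operand; vmap maps over both DK and X.
cond = lambda dk, x: jax.lax.cond(dk, f1, f2, x)
vcond = jax.vmap(jax.vmap(cond))
result = vcond(DK, X)
print(result.tolist())  # [[11.0], [21.0], [29.0], [41.0]]
```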
Note that you can avoid a lot of this complexity by replacing your code with jnp.where. For example:
result = jnp.where(DK, 1, 0)
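This will lower to essentially the same XLA select operation as your original code, because when the predicate of lax.cond is batched under vmap, JAX evaluates both branches and combines them with a select.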
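A quick way to check this, as a sketch (the names via_cond and via_where are just illustrative), is to compute the result both ways, compare them, and print the traced programs with jax.make_jaxpr to see which primitives each version uses:

```python
import jax
import jax.numpy as jnp

DK = jnp.array([[True], [True], [False], [True]])

# Nested-vmap version from the answer above (branches take no operands).
cond = lambda dk: jax.lax.cond(dk, lambda: 1, lambda: 0)
vcond = jax.vmap(jax.vmap(cond))
via_cond = vcond(DK)

# Elementwise version: jnp.where accepts a predicate of any shape, no vmap needed.
via_where = jnp.where(DK, 1, 0)

print(bool(jnp.array_equal(via_cond, via_where)))  # True

# Print the traced programs to compare the primitives each version uses.
print(jax.make_jaxpr(vcond)(DK))
print(jax.make_jaxpr(lambda p: jnp.where(p, 1, 0))(DK))
```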