Here's some data I have:
import numpy as np
import jax.numpy as jnp
import numpyro.distributions as dist
import jax
xaxis = jnp.linspace(-3, 3, 5)
yaxis = jnp.linspace(-3, 3, 5)
I'd like to run the function
def func(x, y):
    return dist.MultivariateNormal(jnp.zeros(2), jnp.array([[.5, .2], [.2, .1]])).log_prob(jnp.asarray([x, y]))
over each pair of values from xaxis and yaxis.
Here's a "slow" way to do it:
results = np.zeros((len(xaxis), len(yaxis)))
for i in range(len(xaxis)):
    for j in range(len(yaxis)):
        results[i, j] = func(xaxis[i], yaxis[j])
Works, but it's slow.
So here's a vectorised way of doing it:
jax.vmap(lambda axis: jax.vmap(func, (None, 0))(axis, yaxis))(xaxis)
Much faster, but it's hard to read.
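One small readability change, sketched here under some assumptions, is to give the inner vmap a name so each level of nesting reads as one step: fix x and map over every y to get a row, then map that row function over every x. To keep the sketch free of a numpyro dependency, it swaps in `jax.scipy.stats.multivariate_normal.logpdf` with the same mean and covariance as a stand-in for `func`; the names `row`, `grid` and `loop` are illustrative, and `yaxis` is given a different length so the axis order is visible.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

# Stand-in for the question's numpyro-based func: same mean and covariance,
# computed with jax.scipy so this sketch does not need numpyro.
MEAN = jnp.zeros(2)
COV = jnp.array([[.5, .2], [.2, .1]])

def func(x, y):
    return multivariate_normal.logpdf(jnp.stack([x, y]), MEAN, COV)

xaxis = jnp.linspace(-3, 3, 5)
yaxis = jnp.linspace(-3, 3, 7)  # different length so the axis order is visible

# Inner vmap: x held fixed, mapped over every y -> one row of the grid.
row = jax.vmap(func, in_axes=(None, 0))
# Outer vmap: the row function mapped over every x -> rows stacked on axis 0.
grid = jax.vmap(row, in_axes=(0, None))(xaxis, yaxis)

# Reference: the slow double loop, so grid[i, j] == func(xaxis[i], yaxis[j]).
loop = jnp.array([[func(x, y) for y in yaxis] for x in xaxis])
print(grid.shape)  # (5, 7)
```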
Is there a clean way of writing the vectorised version? Can I do it with a single vmap, rather than having to nest one within another?
EDIT
Another way would be
xmesh, ymesh = jnp.meshgrid(xaxis, yaxis)
jax.vmap(func)(xmesh.flatten(), ymesh.flatten()).reshape(len(yaxis), len(xaxis)).T
but it's still messy.
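If the flattened approach is kept, passing `indexing="ij"` to `jnp.meshgrid` removes the need for the final transpose, because the meshes then have shape `(len(xaxis), len(yaxis))` with `xmesh[i, j] == xaxis[i]` and `ymesh[i, j] == yaxis[j]`. A minimal sketch, again assuming `jax.scipy.stats.multivariate_normal.logpdf` as a stand-in for the numpyro call, with illustrative names `grid` and `loop`:

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

MEAN = jnp.zeros(2)
COV = jnp.array([[.5, .2], [.2, .1]])

def func(x, y):
    # Stand-in for the question's numpyro log_prob call (scalar x, y).
    return multivariate_normal.logpdf(jnp.stack([x, y]), MEAN, COV)

xaxis = jnp.linspace(-3, 3, 5)
yaxis = jnp.linspace(-3, 3, 7)

# "ij" indexing: both meshes have shape (len(xaxis), len(yaxis)),
# so the flattened results can be reshaped back without a transpose.
xmesh, ymesh = jnp.meshgrid(xaxis, yaxis, indexing="ij")

# A single vmap over the flattened grid, then restore the 2-D shape.
grid = jax.vmap(func)(xmesh.ravel(), ymesh.ravel()).reshape(xmesh.shape)

# Reference values from a plain double loop.
loop = jnp.array([[func(x, y) for y in yaxis] for x in xaxis])
```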
I believe the question "Vectorization guidelines for jax" is quite similar to yours: replicating the logic of nested for-loops with vmap requires nested vmaps.
The cleanest approach using jax.vmap is probably something like this:
from functools import partial
@partial(jax.vmap, in_axes=(0, None))
@partial(jax.vmap, in_axes=(None, 0))
def func(x, y):
    return dist.MultivariateNormal(jnp.zeros(2), jnp.array([[.5, .2], [.2, .1]])).log_prob(jnp.asarray([x, y]))
func(xaxis, yaxis)
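The order of the decorators matters: the decorator closest to the function is applied first and becomes the inner map, while the top one is the outer map and therefore determines axis 0 of the output. The sketch below, with hypothetical names `logpdf_xy`, `grid_xy` and `grid_yx` and the same assumed `jax.scipy` stand-in for the numpyro log-density, shows that swapping the two decorators transposes the result:

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

MEAN = jnp.zeros(2)
COV = jnp.array([[.5, .2], [.2, .1]])

def logpdf_xy(x, y):
    # Stand-in for the question's numpyro log_prob call (scalar x, y).
    return multivariate_normal.logpdf(jnp.stack([x, y]), MEAN, COV)

# Inner map over y (bottom decorator), outer map over x (top decorator):
# x indexes axis 0 of the output.
@partial(jax.vmap, in_axes=(0, None))
@partial(jax.vmap, in_axes=(None, 0))
def grid_xy(x, y):
    return logpdf_xy(x, y)

# Decorators swapped: y becomes the outer map, so y indexes axis 0.
@partial(jax.vmap, in_axes=(None, 0))
@partial(jax.vmap, in_axes=(0, None))
def grid_yx(x, y):
    return logpdf_xy(x, y)

xaxis = jnp.linspace(-3, 3, 5)
yaxis = jnp.linspace(-3, 3, 7)

a = grid_xy(xaxis, yaxis)  # shape (5, 7): a[i, j] = f(xaxis[i], yaxis[j])
b = grid_yx(xaxis, yaxis)  # shape (7, 5): b[j, i] = f(xaxis[i], yaxis[j])
```

Because the decorators replace the function with its vmapped version, it can help to keep an undecorated copy (here `logpdf_xy`) if the scalar function is also needed elsewhere, for example with jnp.vectorize.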
Another option here is to use the jnp.vectorize API (which is implemented via multiple vmaps). Applied to the original, undecorated func, you can do something like this:
print(jnp.vectorize(func)(xaxis[:, None], yaxis))
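With no `signature` argument, `jnp.vectorize` treats each input element as a scalar and broadcasts the arguments NumPy-style, so `xaxis[:, None]` (shape `(len(xaxis), 1)`) against `yaxis` (shape `(len(yaxis),)`) gives a grid indexed as `[i, j]`. A minimal sketch with a non-square grid, assuming the same `jax.scipy` stand-in for the numpyro call and hypothetical names `vfunc`, `grid` and `loop`:

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

MEAN = jnp.zeros(2)
COV = jnp.array([[.5, .2], [.2, .1]])

def func(x, y):
    # Undecorated scalar function (stand-in for the question's numpyro call).
    return multivariate_normal.logpdf(jnp.stack([x, y]), MEAN, COV)

xaxis = jnp.linspace(-3, 3, 5)
yaxis = jnp.linspace(-3, 3, 7)

# No signature: every element is treated as a scalar input, and the
# arguments broadcast like NumPy arrays: (5, 1) with (7,) -> (5, 7).
vfunc = jnp.vectorize(func)
grid = vfunc(xaxis[:, None], yaxis)

# Reference values from a plain double loop.
loop = jnp.array([[func(x, y) for y in yaxis] for x in xaxis])
```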