Vectorise nested vmap

Time:11-05

Here's some data I have:

import jax.numpy as jnp
import numpyro.distributions as dist
import jax

xaxis = jnp.linspace(-3, 3, 5)
yaxis = jnp.linspace(-3, 3, 5)

I'd like to run the function

def func(x, y):
    return dist.MultivariateNormal(jnp.zeros(2), jnp.array([[.5, .2], [.2, .1]])).log_prob(jnp.asarray([x, y]))

over each pair of values from xaxis and yaxis.

Here's a "slow" way to do it:

import numpy as np

results = np.zeros((len(xaxis), len(yaxis)))

for i in range(len(xaxis)):
    for j in range(len(yaxis)):
        results[i, j] = func(xaxis[i], yaxis[j])

Works, but it's slow.

So here's a vectorised way of doing it:

jax.vmap(lambda x: jax.vmap(func, (None, 0))(x, yaxis))(xaxis)

Much faster, but it's hard to read.

Is there a clean way of writing the vectorised version? Can I do it with a single vmap, rather than having to nest one within another one?

EDIT

Another way would be

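xmesh, ymesh = jnp.meshgrid(xaxis, yaxis)
# default indexing="xy": the meshes have shape (len(yaxis), len(xaxis)), hence the .T
jax.vmap(func)(xmesh.flatten(), ymesh.flatten()).reshape(len(yaxis), len(xaxis)).T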

but it's still messy.
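For what it's worth, the meshgrid route can be tidied up by asking for matrix-style indexing, so that xaxis runs along axis 0 and no transpose is needed. A minimal sketch, assuming `jnp.meshgrid(..., indexing="ij")` and using `jax.scipy.stats.multivariate_normal.logpdf` as a stand-in for the numpyro density (same mean and covariance) so the snippet runs without numpyro. The two axes deliberately have different lengths so an axis-order mistake would show up as a wrong shape:

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

MEAN = jnp.zeros(2)
COV = jnp.array([[.5, .2], [.2, .1]])

def func(x, y):
    # Stand-in for the numpyro version: log-density of the same 2-D Gaussian.
    return multivariate_normal.logpdf(jnp.stack([x, y]), MEAN, COV)

xaxis = jnp.linspace(-3, 3, 4)  # different lengths on purpose
yaxis = jnp.linspace(-3, 3, 3)

# indexing="ij" gives xmesh[i, j] == xaxis[i] and ymesh[i, j] == yaxis[j],
# so both meshes already have shape (len(xaxis), len(yaxis)).
xmesh, ymesh = jnp.meshgrid(xaxis, yaxis, indexing="ij")

# A single vmap over the flattened grid, then restore the grid shape.
results = jax.vmap(func)(xmesh.ravel(), ymesh.ravel()).reshape(xmesh.shape)
```

It is still one vmap plus some reshaping, so it trades the nesting for bookkeeping rather than removing it.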

Answer:

I believe the question "Vectorization guidelines for jax" is quite similar to yours: replicating the logic of nested for-loops with vmap requires nested vmaps.
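To make the loop/vmap correspondence concrete, here is a sketch that builds the same grid both ways and checks that they agree. It swaps in `jax.scipy.stats.multivariate_normal.logpdf` for the numpyro call (same mean and covariance) and uses axes of different lengths (4 and 3) so the axis order is visible:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

MEAN = jnp.zeros(2)
COV = jnp.array([[.5, .2], [.2, .1]])

def func(x, y):
    # Stand-in for the numpyro version: log-density of the same 2-D Gaussian.
    return multivariate_normal.logpdf(jnp.stack([x, y]), MEAN, COV)

xaxis = jnp.linspace(-3, 3, 4)
yaxis = jnp.linspace(-3, 3, 3)

# Loop version: the outer loop walks xaxis, the inner loop walks yaxis.
looped = np.zeros((len(xaxis), len(yaxis)))
for i in range(len(xaxis)):
    for j in range(len(yaxis)):
        looped[i, j] = func(xaxis[i], yaxis[j])

# vmap version: each vmap replaces exactly one loop.
inner = jax.vmap(func, in_axes=(None, 0))   # inner loop: hold x fixed, map over y
outer = jax.vmap(inner, in_axes=(0, None))  # outer loop: map over x, pass all of y
vmapped = outer(xaxis, yaxis)
```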

The cleanest approach using jax.vmap is probably something like this:

from functools import partial

@partial(jax.vmap, in_axes=(0, None))  # outer vmap: map over x, pass all of y through
@partial(jax.vmap, in_axes=(None, 0))  # inner vmap: hold x fixed, map over y
def func(x, y):
    return dist.MultivariateNormal(jnp.zeros(2), jnp.array([[.5, .2], [.2, .1]])).log_prob(jnp.asarray([x, y]))

func(xaxis, yaxis)
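With stacked decorators, the argument mapped by the outermost (topmost) vmap becomes axis 0 of the result. A small sketch of what happens when the decorator order is swapped; `logp`, `grid_xy` and `grid_yx` are hypothetical names, and `jax.scipy.stats.multivariate_normal.logpdf` again stands in for the numpyro density:

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

MEAN = jnp.zeros(2)
COV = jnp.array([[.5, .2], [.2, .1]])

def logp(x, y):
    # Stand-in for the numpyro-based density in the question.
    return multivariate_normal.logpdf(jnp.stack([x, y]), MEAN, COV)

# Same stacking as in the answer: the outer vmap maps x, so x indexes axis 0.
@partial(jax.vmap, in_axes=(0, None))
@partial(jax.vmap, in_axes=(None, 0))
def grid_xy(x, y):
    return logp(x, y)

# Decorators swapped: now the outer vmap maps y, so y indexes axis 0.
@partial(jax.vmap, in_axes=(None, 0))
@partial(jax.vmap, in_axes=(0, None))
def grid_yx(x, y):
    return logp(x, y)

xaxis = jnp.linspace(-3, 3, 4)
yaxis = jnp.linspace(-3, 3, 3)

a = grid_xy(xaxis, yaxis)  # shape (4, 3): a[i, j] is logp(xaxis[i], yaxis[j])
b = grid_yx(xaxis, yaxis)  # shape (3, 4): b[j, i] is logp(xaxis[i], yaxis[j])
```

So the decorator order, not the argument order, decides the layout; swapping the decorators gives the transpose.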

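Another option here is to use the jnp.vectorize API (which is implemented via multiple vmaps). Apply it to the original, undecorated func: the decorated version above expects arrays to map over and raises an error on the scalar arguments that jnp.vectorize passes in. Something like this: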

print(jnp.vectorize(func)(xaxis[:, None], yaxis))
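The `xaxis[:, None]` trick works because `jnp.vectorize` broadcasts its arguments like a NumPy ufunc before mapping the scalar function over them. A sketch with unequal axis lengths to show the resulting shape, assuming a plain (undecorated) scalar `func` built on `jax.scipy.stats.multivariate_normal.logpdf` in place of the numpyro call:

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

MEAN = jnp.zeros(2)
COV = jnp.array([[.5, .2], [.2, .1]])

def func(x, y):
    # Plain scalar function (no vmap decorators); stand-in for the numpyro version.
    return multivariate_normal.logpdf(jnp.stack([x, y]), MEAN, COV)

xaxis = jnp.linspace(-3, 3, 4)
yaxis = jnp.linspace(-3, 3, 3)

# (4, 1) broadcast against (3,) gives a (4, 3) grid; jnp.vectorize then
# calls func on each scalar (x, y) pair of that grid.
grid = jnp.vectorize(func)(xaxis[:, None], yaxis)
```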