I want to use vmap to vectorise this code for performance.
def matrix(dataA, dataB):
    return jnp.array([[func(a, b) for b in dataB] for a in dataA])
matrix(data, data)
I tried this:
def f(x, y):
    return func(x, y)
mapped = jax.vmap(f)
mapped(data, data)
But this only gives the diagonal entries.
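A single vmap slices the leading axis of both arguments in lockstep, so call i receives (data[i], data[i]). That is exactly the diagonal. A small sketch, assuming a hypothetical func(x, y) = x * y + 1 in place of the question's func:

```python
import jax
import jax.numpy as jnp

def func(x, y):
    # Hypothetical scalar function standing in for the question's func.
    return x * y + 1.0

data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])

# in_axes defaults to 0 for every argument: both inputs are sliced together,
# so element i of the output is func(data[i], data[i]).
paired = jax.vmap(func)(data, data)

# Reference full matrix built with the question's double loop.
full = jnp.array([[func(a, b) for b in data] for a in data])

print(paired.shape)                                # (5,)
print(bool(jnp.allclose(paired, jnp.diag(full))))  # True
```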
Basically, I have a vector such as data = [1,2,3,4,5], and I want a matrix whose entry (i, j) is f(data[i], data[j]). The resulting matrix therefore has shape (len(data), len(data)).
Answer:
jax.vmap maps across one set of axes at a time. If you want to map across two independent sets of axes, you can do so by nesting two vmap transformations:
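# Inner vmap: x held fixed (None), y mapped over axis 0 -> one row f(x, data[j]) for all j.
# Outer vmap: x mapped over axis 0, y held fixed (None) -> stacks one row per data[i].
mapped = jax.vmap(jax.vmap(f, in_axes=(None, 0)), in_axes=(0, None))
result = mapped(data, data)  # shape (len(data), len(data)); result[i, j] == f(data[i], data[j])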
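To see the nesting at work, here is a self-contained sketch comparing it with the double loop from the question. It assumes a hypothetical func(x, y) = x * y + 1 and uses inputs of different lengths, so the row and column roles show up in the output shape:

```python
import jax
import jax.numpy as jnp

def func(x, y):
    # Hypothetical scalar function standing in for the question's func.
    return x * y + 1.0

def matrix_loop(dataA, dataB):
    # Original approach: one Python-level call per (a, b) pair.
    return jnp.array([[func(a, b) for b in dataB] for a in dataA])

# Inner vmap maps y over axis 0 with x fixed -> a row of length len(dataB).
# Outer vmap maps x over axis 0 with y fixed -> stacks len(dataA) rows.
matrix_vmap = jax.vmap(jax.vmap(func, in_axes=(None, 0)), in_axes=(0, None))

dataA = jnp.array([1.0, 2.0, 3.0])
dataB = jnp.array([10.0, 20.0, 30.0, 40.0])

result = matrix_vmap(dataA, dataB)
print(result.shape)                                           # (3, 4)
print(bool(jnp.allclose(result, matrix_loop(dataA, dataB))))  # True
```

Wrapping matrix_vmap in jax.jit lets XLA compile the whole (len(dataA), len(dataB)) computation as one program, which is typically where the speedup over the Python loop comes from.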