JAX: JIT compatible sparse matrix slicing

Time:12-30

I have a boolean sparse matrix that I represent with row indices and column indices of True values.

import numpy as np
import jax
from jax import numpy as jnp
N = 10000
M = 1000
X = np.random.randint(0, 100, size=(N, M)) == 0  # data setup
rows, cols = np.where(X == True)
rows = jax.device_put(rows)
cols = jax.device_put(cols)

I want to get a column slice of the matrix like X[:, 3], but just from rows indices and column indices.

I managed to do that with jnp.isin as shown below, but this is not JIT compatible because the boolean-mask indexing rows[cols == m] produces an array whose shape depends on the data.

def not_jit_compatible_slice(rows, cols, m):
  return jnp.isin(jnp.arange(N), rows[cols == m])
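
To make the failure concrete, here is a minimal sketch of what happens when that function is wrapped in jax.jit. The tiny 4x2 matrix is made up for illustration, and the exception class named here is the one I would expect recent JAX releases to raise for boolean-mask indexing on a traced array.

```python
import numpy as np
import jax
from jax import numpy as jnp

# Tiny made-up matrix, only for illustration (not the question's data).
X = np.array([[1, 0],
              [0, 1],
              [1, 0],
              [0, 0]], dtype=bool)
N = X.shape[0]
rows_np, cols_np = np.where(X)
rows, cols = jnp.asarray(rows_np), jnp.asarray(cols_np)

def not_jit_compatible_slice(rows, cols, m):
    # rows[cols == m] has a length that depends on the values in cols.
    return jnp.isin(jnp.arange(N), rows[cols == m])

# Outside jit the values are concrete, so the mask can be applied.
eager = not_jit_compatible_slice(rows, cols, 0)

# Under jit the mask is a tracer with unknown contents, so the output
# shape cannot be determined at trace time.
try:
    jax.jit(not_jit_compatible_slice)(rows, cols, 0)
    jit_failed = False
except jax.errors.NonConcreteBooleanIndexError:
    jit_failed = True
```

The eager call returns the expected column, while the jitted call fails during tracing, before any computation runs.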

I could make it JIT compatible by using jnp.where in the three-argument form, but this operation is much slower than the previous one.

def jit_compatible_but_slow_slice(rows, cols, m):
  return jnp.isin(jnp.arange(N), jnp.where(cols == m, rows, -1))

Is there a fast, JIT-compatible way to achieve the same output?

CodePudding user response:

I figured out that the implementation below returns the same output much faster, and it’s JIT compatible.


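def slice(rows, cols, m):
  # One extra slot at index N absorbs the non-matching entries (index -1 wraps to it).
  res = jnp.zeros(N + 1, dtype=bool)
  res = res.at[jnp.where(cols == m, rows, -1)].set(True)
  return res[:-1]  # drop the extra slot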
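
As a sanity check, the padded-scatter approach above can be compared against the dense matrix. This is only a sketch: the sizes are smaller than in the question to keep it quick, and the names slice_padded and padded_matches_dense are mine, not from the question.

```python
import numpy as np
import jax
from jax import numpy as jnp

N, M = 1000, 50  # assumed smaller sizes, just for a quick check

rng = np.random.default_rng(0)
X = rng.integers(0, 10, size=(N, M)) == 0
rows_np, cols_np = np.where(X)
rows, cols = jnp.asarray(rows_np), jnp.asarray(cols_np)

@jax.jit
def slice_padded(rows, cols, m):
    # The index array always has the same shape as rows, so the
    # function traces fine. Non-matching entries are sent to -1,
    # which wraps to the extra slot at position N.
    res = jnp.zeros(N + 1, dtype=bool)
    res = res.at[jnp.where(cols == m, rows, -1)].set(True)
    return res[:-1]  # discard the extra slot

# Compare every column with the dense reference.
padded_matches_dense = all(
    np.array_equal(np.asarray(slice_padded(rows, cols, m)), X[:, m])
    for m in range(M)
)
```

Because m is passed as an ordinary argument, the function should be compiled once and reused for every column.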

CodePudding user response:

You can do a bit better than the first answer by using the mode argument of set() to drop out-of-bound indices, eliminating the final slice:

out = jnp.zeros(N, bool).at[jnp.where(cols==3, rows, N)].set(True, mode='drop')
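
Here is the same one-liner wrapped in a jitted function with m as an argument, as a rough sketch. The small matrix and the name column_slice_drop are made up for illustration. Note that the filler index is N rather than -1: negative indices wrap around Python-style, so -1 would be treated as the last row instead of being dropped.

```python
import numpy as np
import jax
from jax import numpy as jnp

# Tiny made-up matrix: column 1 is empty and row 3 has no True values.
X = np.array([[1, 0, 0],
              [0, 0, 1],
              [1, 0, 0],
              [0, 0, 0]], dtype=bool)
N, M = X.shape
rows_np, cols_np = np.where(X)
rows, cols = jnp.asarray(rows_np), jnp.asarray(cols_np)

@jax.jit
def column_slice_drop(rows, cols, m):
    # Non-matching entries get index N, which is out of bounds for a
    # length-N array, so mode='drop' discards those updates.
    idx = jnp.where(cols == m, rows, N)
    return jnp.zeros(N, dtype=bool).at[idx].set(True, mode='drop')

# One result per column, converted to NumPy for comparison with X.
columns = [np.asarray(column_slice_drop(rows, cols, m)) for m in range(M)]
```

Each entry of columns should match the corresponding X[:, m], including the all-False result for the empty column.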