I have a sparse boolean matrix that I represent by the row indices and column indices of its True values.
```python
import numpy as np
import jax
from jax import numpy as jnp

N = 10000
M = 1000
X = np.random.randint(0, 100, size=(N, M)) == 0  # data setup
rows, cols = np.nonzero(X)  # coordinates of the True values
rows = jax.device_put(rows)
cols = jax.device_put(cols)
```
I want to get a column slice of the matrix, like `X[:, 3]`, but computed only from the row indices and column indices.
I managed to do that with `jnp.isin`, as below, but this is not JIT compatible because `rows[cols == m]` is an array whose shape depends on the data.
```python
def not_jit_compatible_slice(rows, cols, m):
    # rows[cols == m] has a data-dependent length, so this cannot be traced by jit
    return jnp.isin(jnp.arange(N), rows[cols == m])
```
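To see the failure concretely, here is a minimal sketch with a tiny hypothetical matrix (the values of `N`, `rows` and `cols` are made up for illustration). Called eagerly the function works, but wrapping it in `jax.jit` raises an error during tracing, because the length of `rows[cols == m]` is not known until the data is. Recent JAX versions report this as `NonConcreteBooleanIndexError`; the sketch catches a generic `Exception` so it does not depend on the exact error class.

```python
import jax
import jax.numpy as jnp

N = 10  # hypothetical small matrix height
rows = jnp.array([0, 2, 4, 7])  # hypothetical True coordinates
cols = jnp.array([3, 1, 3, 3])

def not_jit_compatible_slice(rows, cols, m):
    return jnp.isin(jnp.arange(N), rows[cols == m])

# Eager call: the boolean mask is concrete, so indexing works (True at rows 0, 4, 7)
eager = not_jit_compatible_slice(rows, cols, 3)

# Under jit the mask is abstract, so the boolean indexing fails while tracing
try:
    jax.jit(not_jit_compatible_slice)(rows, cols, 3)
    jit_failed = False
except Exception as err:  # exact class varies across JAX versions
    jit_failed = True
    print(type(err).__name__)
```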
I could make it JIT compatible by using the three-argument form of `jnp.where`, but this is much slower than the previous version.
```python
def jit_compatible_but_slow_slice(rows, cols, m):
    # Non-matching entries become -1, which never equals a value of jnp.arange(N)
    return jnp.isin(jnp.arange(N), jnp.where(cols == m, rows, -1))
```
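One plausible reason for the slowdown is visible from the shapes alone. In the first version, the second argument of `jnp.isin` holds only the True entries of column `m` (roughly `N / 100`, about 100 here), whereas the `jnp.where` version keeps one entry per True value in the whole matrix (roughly `N * M / 100`, about 100,000), so each of the `N` candidates is checked against a far larger set. The sketch below reproduces the question's setup with a seeded generator (an assumption, for repeatability) and prints both shapes.

```python
import numpy as np
import jax.numpy as jnp

# Same sizes as the question; seeded so the numbers are repeatable
N, M = 10000, 1000
rng = np.random.default_rng(0)
X = rng.integers(0, 100, size=(N, M)) == 0
rows, cols = map(jnp.asarray, np.nonzero(X))

m = 3
filtered = rows[cols == m]               # only the True entries of column m
padded = jnp.where(cols == m, rows, -1)  # one entry per True value in all of X

print(filtered.shape, padded.shape)
```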
Is there a fast, JIT-compatible way to achieve the same output?
Answer:
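I found that the implementation below returns the same output much faster, and it is JIT compatible. It allocates one extra element at the end of the output, sends every non-matching entry to that element through the index `-1`, and then slices it off.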
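```python
def column_slice(rows, cols, m):
    # N + 1 slots: the extra last slot absorbs all non-matching entries
    res = jnp.zeros(N + 1, dtype=bool)
    res = res.at[jnp.where(cols == m, rows, -1)].set(True)
    return res[:-1]  # drop the extra slot
```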
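A small worked example may help show why the extra slot is needed. The sketch below uses a hypothetical 6-row matrix whose True values sit at (0, 2), (1, 0), (3, 2), (5, 2) and (2, 1), with the function above called `column_slice` here. Entries outside column 2 get index `-1`, which JAX, like NumPy, interprets as the last slot (index `N`), so they only ever write into the slot that `res[:-1]` discards.

```python
import jax
import jax.numpy as jnp

N = 6  # hypothetical small matrix height

def column_slice(rows, cols, m):
    # One extra "trash" slot at index N; non-matching entries are sent there
    res = jnp.zeros(N + 1, dtype=bool)
    # -1 refers to the last slot (index N), as in NumPy negative indexing
    res = res.at[jnp.where(cols == m, rows, -1)].set(True)
    return res[:-1]  # drop the trash slot

rows = jnp.array([0, 1, 3, 5, 2])
cols = jnp.array([2, 0, 2, 2, 1])

# Column 2 contains rows 0, 3 and 5; rows 1 and 2 are routed to the trash slot
out = jax.jit(column_slice)(rows, cols, 2)
print(out)
```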
Answer:
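You can do slightly better than the first answer by passing the `mode` argument to `.set()` so that out-of-bounds indices are dropped. Non-matching entries are given the out-of-bounds index `N`, which removes the need for the extra element and the final slice:

```python
out = jnp.zeros(N, bool).at[jnp.where(cols == 3, rows, N)].set(True, mode='drop')
```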
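As a sanity check, the `mode='drop'` version can be wrapped in a `jax.jit`-compiled function and compared with the dense slice `X[:, m]` for every column. This is a sketch under assumptions: the matrix is smaller than in the question (hypothetical `N = 1000`, `M = 50`) and seeded so the check runs quickly and repeatably, and the function name `column_slice` is illustrative. Because `m` is passed as an ordinary traced argument, the function is compiled once and reused for every column.

```python
import numpy as np
import jax
import jax.numpy as jnp

N, M = 1000, 50  # hypothetical sizes, smaller than in the question
rng = np.random.default_rng(0)
X = rng.integers(0, 100, size=(N, M)) == 0
rows, cols = map(jnp.asarray, np.nonzero(X))

@jax.jit
def column_slice(rows, cols, m):
    # Non-matching entries get index N, which is out of bounds and dropped
    idx = jnp.where(cols == m, rows, N)
    return jnp.zeros(N, dtype=bool).at[idx].set(True, mode='drop')

# Compare every column with the dense NumPy slice X[:, m]
all_match = all(
    np.array_equal(np.asarray(column_slice(rows, cols, m)), X[:, m])
    for m in range(M)
)
print(all_match)
```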