How to check if a value is in an array while using JAX

Time:04-05

I have a negative sampling function that I want to compile with JAX's @jit, but everything I try makes it stop working.

The parameters are:

  • key: a PRNG key from jax.random;
  • ratings: a list of 3-tuples (user_id, item_id, 1);
  • user_positives: a list of lists where the i-th inner list contains the items that the i-th user has consumed;
  • num_items: the total number of items.

My function is shown below, and its goal is to draw 100 samples from ratings and, for each sample, retrieve an item that has not been consumed by that user.

BATCH_SIZE = 100

@jit
def sample(key, ratings, user_positives, num_items):
    new_key, subkey = jax.random.split(key)
    
    sampled_ratings = jax.random.choice(subkey, ratings, shape=(BATCH_SIZE,))
    sampled_users = jnp.zeros(BATCH_SIZE)
    sampled_positives = jnp.zeros(BATCH_SIZE)
    sampled_negatives = jnp.zeros(BATCH_SIZE)
    idx = 0
    
    for u, i, r in sampled_ratings:
        negative = user_positives[u][0]
        new_key, subkey = jax.random.split(key)
        while jnp.isin(jnp.array([negative]), user_positives[u])[0]:
            negative = jax.random.randint(current_subkey, (1,), 0, num_items)
            current_subkey = jax.random.split(subkey)
        
        sampled_users.at[idx].set(u)
        sampled_positives.at[idx].set(i)
        sampled_negatives.at[idx].set(negative)
        idx += 1
    
    return new_key, sampled_users, sampled_positives, sampled_negatives

However, every change I make produces a new error, and I am now stuck on the one below. Can anyone help me do this?

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function sample at /tmp/ipykernel_11557/2294038851.py:1 for jit, this concrete value was not available in Python because it depends on the values of the arguments 'key', 'ratings', and 'user_positives'.

Edit 1: An input example would be:

rng_key = 
rng_key, su, sp, sn = sample(
    rng_key,
    np.array([(0, 0, 1), (0, 1, 1), (1, 2, 1)]),
    np.array([np.array([0, 1]), np.array([2])]),
    15
)

CodePudding user response:

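In general, if you want to jit-compile a while loop whose condition depends on non-static quantities, you'll have to express it in terms of jax.lax.while_loop. A Python while statement calls bool() on its condition, and under jit that condition is a traced value with no concrete True/False, which is what the ConcretizationTypeError is reporting. For more information see JAX Sharp Bits: Structured control flow primitives.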

I'll try to edit my answer with an example based on your code if you can add an example of the expected input.
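In the meantime, here is a rough sketch of how the rejection loop could look with jax.lax.while_loop, based on the input from Edit 1. It rests on assumptions that are not in the question: the ragged user_positives is replaced by a rectangular integer array (called padded_positives here, a made-up name) padded with -1, every user has at least one real item in column 0, and every user has at least one unconsumed item (otherwise the loop never terminates). The helper _draw_negative is also hypothetical. Treat it as a starting point, not a drop-in replacement.

```python
import jax
import jax.numpy as jnp

BATCH_SIZE = 100


def _draw_negative(key, positives, num_items):
    """Rejection-sample one item id in [0, num_items) not present in `positives`."""

    def cond(state):
        _, candidate = state
        # Traced boolean: allowed here, unlike in a Python `while`.
        # The -1 padding never matches because candidates are >= 0.
        return jnp.isin(candidate, positives)

    def body(state):
        key, _ = state
        key, subkey = jax.random.split(key)
        candidate = jax.random.randint(
            subkey, (), 0, num_items, dtype=positives.dtype
        )
        return key, candidate

    # Assumption: positives[0] is a real consumed item, so the body runs at least once.
    _, negative = jax.lax.while_loop(cond, body, (key, positives[0]))
    return negative


@jax.jit
def sample(key, ratings, padded_positives, num_items):
    new_key, choice_key, neg_key = jax.random.split(key, 3)

    # Draw BATCH_SIZE rows (user_id, item_id, 1) with replacement.
    rows = jax.random.choice(choice_key, ratings, shape=(BATCH_SIZE,))
    users, positives = rows[:, 0], rows[:, 1]

    # One independent key per sample; vmap replaces the Python for loop
    # and the .at[idx].set(...) bookkeeping.
    neg_keys = jax.random.split(neg_key, BATCH_SIZE)
    negatives = jax.vmap(_draw_negative, in_axes=(0, 0, None))(
        neg_keys, padded_positives[users], num_items
    )
    return new_key, users, positives, negatives


ratings = jnp.array([(0, 0, 1), (0, 1, 1), (1, 2, 1)])
padded_positives = jnp.array([[0, 1], [2, -1]])  # user 1 padded with -1
rng_key = jax.random.PRNGKey(0)  # seed chosen arbitrarily for this sketch
rng_key, su, sp, sn = sample(rng_key, ratings, padded_positives, 15)
```

The membership test itself is just jnp.isin on traced values; the part that has to change is the control flow around it. Padding is needed because jit works on arrays with fixed shapes, so a list of lists with different lengths cannot be passed in directly.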
