jax.lax.select vs jax.numpy.where

Time:01-03

I was looking at the dropout implementation in Flax:

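# Imports the excerpt relies on (it is a method of flax.linen.Dropout):
from typing import Optional

import jax.numpy as jnp
from jax import lax, random
from flax.linen.module import merge_param

def __call__(self, inputs, deterministic: Optional[bool] = None):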
    """Applies a random dropout mask to the input.

    Args:
      inputs: the inputs that should be randomly masked.
      deterministic: if false the inputs are scaled by `1 / (1 - rate)` and
        masked, whereas if true, no mask is applied and the inputs are returned
        as is.

    Returns:
      The masked inputs reweighted to preserve mean.
    """
    deterministic = merge_param(
        'deterministic', self.deterministic, deterministic)

    if (self.rate == 0.) or deterministic:
      return inputs

    # Prevent gradient NaNs in 1.0 edge-case.
    if self.rate == 1.0:
      return jnp.zeros_like(inputs)

    keep_prob = 1. - self.rate
    rng = self.make_rng(self.rng_collection)
    broadcast_shape = list(inputs.shape)
    for dim in self.broadcast_dims:
      broadcast_shape[dim] = 1
    mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
    mask = jnp.broadcast_to(mask, inputs.shape)
    return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))

In particular, I'm interested in the last line, lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs)). I'm wondering why lax.select is used here instead of:

return jnp.where(mask, inputs / keep_prob, 0)

or even more simply:

return mask * inputs / keep_prob
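For ordinary finite inputs, all three candidate expressions produce the same values. The sketch below uses a hypothetical helper, dropout_three_ways (not part of Flax), that draws one Bernoulli keep-mask and applies it all three ways; it assumes float32 inputs and a fixed PRNG key, and only illustrates the comparison rather than reproducing Flax's module.

```python
import jax.numpy as jnp
from jax import lax, random


def dropout_three_ways(rng, inputs, rate):
    """Apply one keep-mask three ways (hypothetical helper, not a Flax API)."""
    keep_prob = 1.0 - rate
    mask = random.bernoulli(rng, p=keep_prob, shape=inputs.shape)
    # 1. What flax.linen.Dropout does: a select with matching shapes and dtypes.
    via_select = lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))
    # 2. jnp.where broadcasts the scalar 0 and promotes dtypes automatically.
    via_where = jnp.where(mask, inputs / keep_prob, 0)
    # 3. Multiplying by the boolean mask (bool is promoted to the input dtype).
    via_mul = mask * inputs / keep_prob
    return via_select, via_where, via_mul


x = jnp.arange(1.0, 9.0).reshape(2, 4)  # finite float32 inputs
a, b, c = dropout_three_ways(random.PRNGKey(0), x, rate=0.5)
print(jnp.array_equal(a, b), jnp.array_equal(a, c))
```

The differences between the three only show up at the edges: mismatched shapes or dtypes, and non-finite inputs.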

Answer:

jnp.where is essentially the same as lax.select, except that it is more flexible about its inputs: it broadcasts them to a common shape and promotes them to a common dtype, whereas lax.select requires its inputs to match exactly:

>>> import jax.numpy as jnp
>>> from jax import lax
>>> x = jnp.arange(3)
# Implicit broadcasting
>>> jnp.where(x < 2, x[:, None], 0)
DeviceArray([[0, 0, 0],
             [1, 1, 0],
             [2, 2, 0]], dtype=int32)

>>> lax.select(x < 2, x[:, None], 0)
TypeError: select cases must have the same shapes, got [(), (3, 1)].
# Implicit type promotion
>>> jnp.where(x < 2, jnp.zeros(3), jnp.arange(3))
DeviceArray([0., 0., 2.], dtype=float32)

>>> lax.select(x < 2, jnp.zeros(3), jnp.arange(3))
TypeError: lax.select requires arguments to have the same dtypes, got int32, float32. (Tip: jnp.where is a similar function that does automatic type promotion on inputs).
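The two failing lax.select calls above work once the broadcasting and casting that jnp.where performs implicitly are done by hand. A minimal sketch using jnp.broadcast_shapes, jnp.broadcast_to and astype (the variable names are illustrative), reproducing both jnp.where results:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(3)
pred = x < 2

# jnp.where broadcasts (3,) against (3, 1) and the scalar 0 implicitly.
expected = jnp.where(pred, x[:, None], 0)

# lax.select needs every operand spelled out with one common shape and,
# for the two branches, one common dtype.
shape = jnp.broadcast_shapes(pred.shape, x[:, None].shape)  # (3, 3)
manual = lax.select(
    jnp.broadcast_to(pred, shape),
    jnp.broadcast_to(x[:, None], shape),
    jnp.zeros(shape, dtype=x.dtype),
)
print(jnp.array_equal(expected, manual))

# Explicit cast instead of implicit int32 -> float32 promotion.
promoted = lax.select(pred, jnp.zeros(3), jnp.arange(3).astype(jnp.float32))
print(jnp.array_equal(promoted, jnp.where(pred, jnp.zeros(3), jnp.arange(3))))
```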

Library code is one place where the stricter semantics are useful: rather than smoothing over a potential implementation bug and returning unexpected output, lax.select fails loudly. Performance-wise, though, the two are essentially equivalent (especially once JIT-compiled), because jnp.where itself works by broadcasting and promoting its arguments and then calling lax.select.
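One way to check the equivalence is to inspect the traced programs with jax.make_jaxpr. A minimal sketch, assuming the inputs already share a shape and dtype so jnp.where has nothing to broadcast or promote; the primitive's printed name depends on the JAX version (recent releases call it select_n):

```python
import jax
import jax.numpy as jnp
from jax import lax

pred = jnp.array([True, False, True])
a = jnp.arange(3.0)
b = jnp.zeros(3)

# Trace both versions to jaxprs (JAX's intermediate representation).
where_jaxpr = jax.make_jaxpr(lambda p, x, y: jnp.where(p, x, y))(pred, a, b)
select_jaxpr = jax.make_jaxpr(lambda p, x, y: lax.select(p, x, y))(pred, a, b)

# Both reduce to the same select primitive.
print(where_jaxpr)
print(select_jaxpr)
```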

As for why the Flax developers chose lax.select over multiplying by a mask, I can think of two reasons:

  1. Multiplying by a mask is subject to implicit type promotion semantics, and it takes a lot more thought to anticipate problematic outputs than with a simple select, which is designed specifically for this operation. For example, if a dropped element is inf or NaN, multiplying it by zero yields NaN, whereas select simply returns 0.
  2. Using multiplication causes the compiler to treat this operation as a multiplication, which it is not. A select is a much narrower and more precise operation than a multiplication, and specifying operations precisely often lets the compiler optimize the result to a greater extent.