I was looking at the dropout implementation in Flax:
```python
def __call__(self, inputs, deterministic: Optional[bool] = None):
  """Applies a random dropout mask to the input.

  Args:
    inputs: the inputs that should be randomly masked.
    deterministic: if false the inputs are scaled by `1 / (1 - rate)` and
      masked, whereas if true, no mask is applied and the inputs are returned
      as is.

  Returns:
    The masked inputs reweighted to preserve mean.
  """
  deterministic = merge_param(
      'deterministic', self.deterministic, deterministic)
  if (self.rate == 0.) or deterministic:
    return inputs
  # Prevent gradient NaNs in 1.0 edge-case.
  if self.rate == 1.0:
    return jnp.zeros_like(inputs)
  keep_prob = 1. - self.rate
  rng = self.make_rng(self.rng_collection)
  broadcast_shape = list(inputs.shape)
  for dim in self.broadcast_dims:
    broadcast_shape[dim] = 1
  mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
  mask = jnp.broadcast_to(mask, inputs.shape)
  return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))
```
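To experiment with this outside of a Flax module, here is a minimal standalone sketch of the same masking logic written as a plain function. The name `dropout` and its signature are hypothetical (this is not a Flax API), and the PRNG key and inputs are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax


def dropout(rng, inputs, rate, deterministic=False, broadcast_dims=()):
    """Hypothetical functional version of the module logic above (not a Flax API)."""
    if rate == 0.0 or deterministic:
        return inputs
    if rate == 1.0:
        # Mirrors the module's early return for the all-dropped case.
        return jnp.zeros_like(inputs)
    keep_prob = 1.0 - rate
    # Axes listed in broadcast_dims get size 1, so they share one mask value.
    broadcast_shape = list(inputs.shape)
    for dim in broadcast_dims:
        broadcast_shape[dim] = 1
    mask = jax.random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
    mask = jnp.broadcast_to(mask, inputs.shape)
    # Kept entries are scaled by 1 / keep_prob; dropped entries are exactly zero.
    return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))


key = jax.random.PRNGKey(0)
x = jnp.arange(1.0, 9.0).reshape(2, 4)
y = dropout(key, x, rate=0.5)
# With broadcast_dims=(0,), both rows share the same mask, so each column is
# either kept in full or dropped in full.
y_shared = dropout(key, x, rate=0.5, broadcast_dims=(0,))
```

With `rate=0.5`, every entry of `y` is either `0` or `2 * x`, which is the "reweighted to preserve mean" behavior the docstring describes.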
In particular, I'm interested in the last line, `lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))`. I'm wondering why `lax.select` is used here instead of:

```python
return jnp.where(mask, inputs / keep_prob, 0)
```

or, even more simply:

```python
return mask * inputs / keep_prob
```
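For ordinary finite inputs, the three spellings produce the same values. A quick check, assuming an arbitrary PRNG key, `keep_prob = 0.8`, and a small float32 input chosen only for illustration:

```python
import jax
import jax.numpy as jnp
from jax import lax

keep_prob = 0.8
inputs = jnp.linspace(-2.0, 2.0, 8)  # finite float32 values
mask = jax.random.bernoulli(jax.random.PRNGKey(42), p=keep_prob, shape=inputs.shape)

# The version used in Flax.
via_select = lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))
# jnp.where broadcasts the scalar 0 and promotes it to float32.
via_where = jnp.where(mask, inputs / keep_prob, 0)
# The boolean mask is promoted to 0.0 / 1.0 before multiplying.
via_mul = mask * inputs / keep_prob
```

So the question is less about the values for well-behaved inputs and more about semantics and edge cases.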
`jnp.where` is basically the same as `lax.select`, except more flexible in its inputs: for example, it will broadcast inputs to the same shape or cast them to the same dtype, whereas `lax.select` requires the inputs to match more strictly:
```python
>>> import jax.numpy as jnp
>>> from jax import lax
>>> x = jnp.arange(3)

# Implicit broadcasting
>>> jnp.where(x < 2, x[:, None], 0)
DeviceArray([[0, 0, 0],
             [1, 1, 0],
             [2, 2, 0]], dtype=int32)
>>> lax.select(x < 2, x[:, None], 0)
TypeError: select cases must have the same shapes, got [(), (3, 1)].

# Implicit type promotion
>>> jnp.where(x < 2, jnp.zeros(3), jnp.arange(3))
DeviceArray([0., 0., 2.], dtype=float32)
>>> lax.select(x < 2, jnp.zeros(3), jnp.arange(3))
TypeError: lax.select requires arguments to have the same dtypes, got int32, float32. (Tip: jnp.where is a similar function that does automatic type promotion on inputs).
```
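To get `lax.select` to accept the same inputs, you have to do the broadcasting and casting yourself. This sketch reproduces both `jnp.where` results from the session above with explicit `jnp.broadcast_to` and `astype` calls; the variable names are illustrative.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(3)

# Broadcasting case: jnp.where works out the (3, 3) result shape for us.
expected = jnp.where(x < 2, x[:, None], 0)

# lax.select needs the predicate and both cases to share one shape, and both
# cases to share one dtype, so everything is spelled out explicitly.
shape = (3, 3)
pred = jnp.broadcast_to(x < 2, shape)
on_true = jnp.broadcast_to(x[:, None], shape)
on_false = jnp.zeros(shape, dtype=on_true.dtype)
manual = lax.select(pred, on_true, on_false)

# Type-promotion case: cast the integer branch to the float branch's dtype.
zeros = jnp.zeros(3)
promoted = lax.select(x < 2, zeros, jnp.arange(3).astype(zeros.dtype))
```

Having to write these steps out is exactly the strictness that lets `lax.select` catch a shape or dtype mismatch that `jnp.where` would silently resolve.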
Library code is one place where the stricter semantics can be useful: rather than smoothing over potential implementation bugs and returning an unexpected output, `lax.select` will complain loudly. Performance-wise, though (especially once JIT-compiled), the two are essentially equivalent.
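One way to see the performance equivalence is to trace both spellings with `jax.make_jaxpr`. As far as I know, `jnp.where` is implemented on top of `lax.select`, so once the inputs already match, both trace to the same select primitive (named `select_n` in recent JAX releases and `select` in older ones); the inputs here are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax import lax

mask = jnp.array([True, False, True])
vals = jnp.array([1.0, 2.0, 3.0])


def with_where(m, v):
    return jnp.where(m, v, jnp.zeros_like(v))


def with_select(m, v):
    return lax.select(m, v, jnp.zeros_like(v))


# Trace both functions to JAX's intermediate representation.
where_jaxpr = jax.make_jaxpr(with_where)(mask, vals)
select_jaxpr = jax.make_jaxpr(with_select)(mask, vals)
print(where_jaxpr)
print(select_jaxpr)

# Run both under jit to compare outputs.
where_out = jax.jit(with_where)(mask, vals)
select_out = jax.jit(with_select)(mask, vals)
```

The extra flexibility of `jnp.where` (promotion, broadcasting) is handled while tracing, so it adds nothing to the compiled computation when no conversion is needed.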
As for why the Flax developers chose `lax.select` over multiplying by a mask, I can think of two reasons:
- Multiplying by a mask is subject to implicit type promotion semantics, and anticipating problematic outputs takes a lot more thought than with a simple `select`, which is specifically designed for the intended operation.
- Using multiplication causes the compiler to treat this operation as a multiplication, which it is not. A `select` is a much narrower and more precise operation than a multiplication, and specifying operations precisely often allows the compiler to optimize the result to a greater extent.