jax.nn.softmax is defined as:
def softmax(x: Array,
            axis: Optional[Union[int, Tuple[int, ...]]] = -1,
            where: Optional[Array] = None,
            initial: Optional[Array] = None) -> Array:
  x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
  unnormalized = jnp.exp(x - lax.stop_gradient(x_max))
  return unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
I'm particularly interested in the lax.stop_gradient(x_max) part, and would love an explanation of why it's needed. From a practical standpoint, stop_gradient doesn't seem to change the gradient calculation at all:
import jax
import jax.numpy as jnp
def softmax_unstable(x):
    return jnp.exp(x) / jnp.sum(jnp.exp(x))

def softmax_stable(x):
    x = x - jnp.max(x)
    return jnp.exp(x) / jnp.sum(jnp.exp(x))

def softmax_stop_gradient(x):
    x = x - jax.lax.stop_gradient(jnp.max(x))
    return jnp.exp(x) / jnp.sum(jnp.exp(x))
# example input
x = jax.random.normal(jax.random.PRNGKey(123), (100,))
# make sure all forward passes are equal
a = softmax_unstable(x)
b = softmax_stable(x)
c = softmax_stop_gradient(x)
d = jax.nn.softmax(x)
assert jnp.allclose(a, b) and jnp.allclose(b, c) and jnp.allclose(c, d)
# make sure all gradient calculations are the same
a = jax.grad(lambda x: -jnp.log(softmax_unstable(x))[2])(x)
b = jax.grad(lambda x: -jnp.log(softmax_stable(x))[2])(x)
c = jax.grad(lambda x: -jnp.log(softmax_stop_gradient(x))[2])(x)
d = jax.grad(lambda x: -jnp.log(jax.nn.softmax(x))[2])(x)
assert jnp.allclose(a, b) and jnp.allclose(b, c) and jnp.allclose(c, d)
# make sure all gradient calculations are the same, this time we use softmax functions twice
a = jax.grad(lambda x: -jnp.log(softmax_unstable(softmax_unstable(x)))[2])(x)
b = jax.grad(lambda x: -jnp.log(softmax_stable(softmax_stable(x)))[2])(x)
c = jax.grad(lambda x: -jnp.log(softmax_stop_gradient(softmax_stop_gradient(x)))[2])(x)
d = jax.grad(lambda x: -jnp.log(jax.nn.softmax(jax.nn.softmax(x)))[2])(x)
assert jnp.allclose(a, b) and jnp.allclose(b, c) and jnp.allclose(c, d)
^ All implementations agree, even the one that applies the x - x_max trick but WITHOUT stop_gradient.
CodePudding user response:
First off, the reason for subtracting x_max at all is that it prevents overflow for large inputs. For example:
x = jnp.array([1, 2, 1000])
print(softmax_unstable(x))
# [ 0. 0. nan]
print(softmax_stable(x))
# [0. 0. 1.]
print(softmax_stop_gradient(x))
# [0. 0. 1.]
As for why we use stop_gradient here: we can show analytically that the max(x) term cancels out in the gradient computation, so we know a priori that its gradient cannot affect the gradient of the overall function (a quick numerical check of this cancellation is sketched after the jaxprs below). Marking it with stop_gradient communicates this to JAX's autodiff machinery, leading to a more efficient gradient computation. You can see this efficiency in action by printing the jaxpr for each version of the gradient function:
x = jnp.float32(1)
print(jax.make_jaxpr(jax.grad(softmax_stable))(x))
{ lambda ; a:f32[]. let
b:f32[] = reduce_max[axes=()] a
c:f32[] = reshape[dimensions=None new_sizes=()] b
d:bool[] = eq a c
e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
f:f32[] = reduce_sum[axes=()] e
g:f32[] = sub a b
h:f32[] = exp g
i:f32[] = exp g
j:f32[] = reduce_sum[axes=()] i
_:f32[] = div h j
k:f32[] = integer_pow[y=-2] j
l:f32[] = mul 1.0 k
m:f32[] = mul l h
n:f32[] = neg m
o:f32[] = div 1.0 j
p:f32[] = mul n i
q:f32[] = mul o h
r:f32[] = add_any p q
s:f32[] = neg r
t:f32[] = div s f
u:f32[] = mul t e
v:f32[] = add_any r u
in (v,) }
print(jax.make_jaxpr(jax.grad(softmax_stop_gradient))(x))
{ lambda ; a:f32[]. let
b:f32[] = reduce_max[axes=()] a
c:f32[] = reshape[dimensions=None new_sizes=()] b
d:bool[] = eq a c
e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
_:f32[] = reduce_sum[axes=()] e
f:f32[] = stop_gradient b
g:f32[] = sub a f
h:f32[] = exp g
i:f32[] = exp g
j:f32[] = reduce_sum[axes=()] i
_:f32[] = div h j
k:f32[] = integer_pow[y=-2] j
l:f32[] = mul 1.0 k
m:f32[] = mul l h
n:f32[] = neg m
o:f32[] = div 1.0 j
p:f32[] = mul n i
q:f32[] = mul o h
r:f32[] = add_any p q
in (r,) }
The second version requires fewer computations to achieve the same result, because we've essentially told the autodiff machinery it does not have to worry about differentiating max(x).
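For completeness, here is the cancellation mentioned above, checked numerically. Softmax is invariant to subtracting any constant: softmax(x - c) == softmax(x), so the gradient of any softmax output with respect to that constant is identically zero, which is why routing gradients through max(x) is wasted work. A small sketch (shifted_softmax is a hypothetical helper for illustration, not part of jax.nn):
import jax
import jax.numpy as jnp

def shifted_softmax(x, c):
    # softmax computed after subtracting an arbitrary scalar shift c
    z = jnp.exp(x - c)
    return z / jnp.sum(z)

x = jax.random.normal(jax.random.PRNGKey(0), (5,))
# gradient of one softmax output w.r.t. the shift c
g = jax.grad(lambda c: shifted_softmax(x, c)[2])(1.0)
print(g)  # ~0.0 (exactly zero analytically, up to floating-point round-off numerically)
This is the same fact the question's experiments show from the other direction: since the gradient flowing through the subtracted max is zero anyway, stop_gradient changes nothing except how much work the backward pass does.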
CodePudding user response:
That is a very good question! As you already observed, subtracting the maximum does not really change softmax's output, and thus the gradients are the same. So why stop the gradient? Well... the answer is pretty simple: it saves compute. This way JAX does not have to trace gradients back through the max computation, only for them to cancel out anyway. Shorter compile times, less compute, less strain on the graph optimizer, and less chance of odd numerical issues creeping in.
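One quick way to see the savings without reading the jaxprs line by line is to count the primitive equations in each backward pass. A rough sketch reusing the question's definitions; the equation count is only a proxy for actual cost, and the exact numbers vary across JAX versions:
import jax
import jax.numpy as jnp

def softmax_stable(x):
    x = x - jnp.max(x)
    return jnp.exp(x) / jnp.sum(jnp.exp(x))

def softmax_stop_gradient(x):
    x = x - jax.lax.stop_gradient(jnp.max(x))
    return jnp.exp(x) / jnp.sum(jnp.exp(x))

x = jnp.float32(1)
# number of primitive operations in each gradient computation
print(len(jax.make_jaxpr(jax.grad(softmax_stable))(x).jaxpr.eqns))         # more equations
print(len(jax.make_jaxpr(jax.grad(softmax_stop_gradient))(x).jaxpr.eqns))  # fewer equations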
So it is not needed, but it is beneficial :)