JAX: Getting rid of zero-gradient

Time:08-30

Is there a way to modify this function (MyFunc) so that it gives the same result but has a non-zero gradient?

from jax import grad
import jax.nn as nn
import numpy as np

def MyFunc(coefs):
    a = coefs[0]
    b = coefs[1]
    c = coefs[2]

    if a > b:
        return 30.0
    elif b > c:
        return 20.0
    else:
        return 10.0

myFuncDeriv = grad(MyFunc)

# prints [0. 0. 0.]
print(myFuncDeriv(np.random.sample(3)))
# prints [0. 0. 0.]
print(myFuncDeriv(np.array([1.0, 2.0, 3.0])))

EDIT: Here is a similar function that doesn't give a zero gradient, but it doesn't return 30/20/10:

def MyFunc2(coefs):
    a = coefs[0]
    b = coefs[1]
    c = coefs[2]
    if a > b:
        return nn.sigmoid(a)*30.0
    if b > c:
        return nn.sigmoid(b)*20.0
    else:
        return nn.sigmoid(c)*10.0


myFunc2Deriv = grad(MyFunc2)

# prints [0.         0.         0.45176652]
print(myFunc2Deriv(np.array([1.0, 2.0, 3.0])))
# prints for example [6.1160526 0.        0.       ]
print(myFunc2Deriv(np.random.sample(3)))

CodePudding user response:

The gradient of your function is zero because that is the correct gradient for the function as you defined it: its output is piecewise constant, so a small change to any input leaves the output unchanged. For more on this, see the JAX FAQ entry "Why are gradients zero for functions based on sort order?"
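
To see why zero is the right answer, here is a small sketch that nudges each input by a tiny amount and measures how much the output moves. MyFunc is copied from the question; the step size eps and the test point are just values I picked for illustration.

```python
import jax.numpy as jnp

def MyFunc(coefs):
    a, b, c = coefs[0], coefs[1], coefs[2]
    if a > b:
        return 30.0
    elif b > c:
        return 20.0
    else:
        return 10.0

x = jnp.array([1.0, 2.0, 3.0])
eps = 1e-3  # assumed step size, small enough not to change the ordering

# Finite-difference slope along each coordinate.
slopes = []
for i in range(3):
    nudged = x.at[i].add(eps)
    slopes.append((MyFunc(nudged) - MyFunc(x)) / eps)

print(slopes)  # [0.0, 0.0, 0.0]
```

The output only changes when an input crosses one of the comparison boundaries, and at those points the function jumps rather than slopes, so there is no useful gradient there either.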

If you want a sort-based function with non-zero gradients, you can achieve this by replacing your step-wise function with a smooth approximation. The sigmoid version you included in your question seems like a reasonable approach for this approximation.
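
As one possible sketch of such a smooth approximation (not the only way to do it), you could replace each comparison with a sigmoid of the difference and blend the three branch values. The name soft_my_func and the temperature parameter are my own assumptions, not part of your code or of JAX.

```python
import jax
import jax.numpy as jnp

def soft_my_func(coefs, temperature=0.1):
    # Hypothetical smooth stand-in for MyFunc; `temperature` is an assumed
    # knob: smaller values get closer to the hard 30/20/10 steps.
    a, b, c = coefs[0], coefs[1], coefs[2]
    # Soft versions of the tests `a > b` and `b > c`, each in (0, 1).
    p_ab = jax.nn.sigmoid((a - b) / temperature)
    p_bc = jax.nn.sigmoid((b - c) / temperature)
    # Blend the branch values the way the if/elif/else would select them.
    return p_ab * 30.0 + (1.0 - p_ab) * (p_bc * 20.0 + (1.0 - p_bc) * 10.0)

soft_grad = jax.grad(soft_my_func)

x = jnp.array([1.0, 2.0, 3.0])
print(soft_my_func(x))  # close to 10.0, but not exactly 10.0
print(soft_grad(x))     # small but non-zero entries
```

Far from the boundaries this returns values close to 30/20/10, and near a boundary it interpolates between them, which is where the gradient becomes informative. The trade-off is that the outputs are no longer exactly 30/20/10.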

But note that what you are literally asking for (a function that produces the same output but has non-zero gradients) is impossible: any function that returns the same outputs as yours for all inputs is the same function, and so has a zero gradient by definition.
