Computational complexity of higher-order derivatives with AD in JAX

Time:02-03

Let f: R -> R be an infinitely differentiable function. What is the computational complexity of calculating the first n derivatives of f in JAX? A naive application of the chain rule suggests that each multiplication doubles the work, so the nth derivative would require at least 2^n times as many operations. I imagine, though, that clever manipulation of formal series would reduce the number of required calculations and eliminate duplication, especially if the derivatives are jitted in JAX? Is there a difference between the JAX, TensorFlow and Torch implementations?

https://openreview.net/forum?id=SkxEF3FNPH discusses this topic, but doesn't provide a computational complexity.

CodePudding user response:

What is the computational complexity of calculating the first n derivatives of f in Jax?

There's not much you can say in general about the computational complexity of Nth derivatives, because it depends on the function. For example, with a function like jnp.sin, the Nth derivative is O(1) in N: it just cycles through positive and negative sin and cos as N grows. For a degree-k polynomial, the Nth derivative is identically zero for N > k, so it costs essentially nothing to evaluate. Other functions may have complexity that is linear, polynomial, or even exponential in N, depending on the operations they contain.
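As a rough way to explore this for a particular function, here is a minimal sketch that nests jax.grad and counts the equations in the traced program. The helper names nth_derivative and num_eqns are made up for this illustration, and the equation count is only a pre-optimization proxy for cost, since XLA may simplify the program further under jit.

```python
import jax
import jax.numpy as jnp


def nth_derivative(f, n):
    """Return the n-th derivative of a scalar function by nesting jax.grad."""
    for _ in range(n):
        f = jax.grad(f)
    return f


def num_eqns(f, x):
    """Count top-level equations in the jaxpr traced for f at x (before XLA optimization)."""
    return len(jax.make_jaxpr(f)(x).jaxpr.eqns)


x = 0.3
for n in range(5):
    df = nth_derivative(jnp.sin, n)
    # Print the order, the derivative value, and the size of the traced program.
    print(n, float(df(x)), num_eqns(df, x))
```

Swapping jnp.sin for another function and watching how the equation count changes with n gives an empirical feel for how nested differentiation scales for that function.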

I imagine, though, that clever manipulation of formal series would reduce the number of required calculations and eliminate duplication, especially if the derivatives are jitted in JAX

You imagine correctly! One implementation of this idea is the jax.experimental.jet module, which is an experimental transform designed for computing higher-order derivatives efficiently and accurately. It doesn't cover all JAX functions, but it may be complete enough to do what you have in mind.
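A minimal sketch of how this might look for jnp.sin, assuming the jet(fun, primals, series) call form from the module's documentation. Because the module is experimental, its API may change. Seeding the input series with (1, 0, 0, ...) corresponds to the path x + t, so the output terms are the successive derivatives of f at x.

```python
import jax.numpy as jnp
from jax.experimental.jet import jet

x = 0.3
order = 4

# Input path h(t) = x + t: first derivative 1, all higher derivatives 0.
series_in = [1.0] + [0.0] * (order - 1)

# One pass propagates the whole truncated series through the function.
f0, derivs = jet(jnp.sin, (x,), (series_in,))

print(float(f0))                   # value of sin(x)
print([float(d) for d in derivs])  # derivatives of order 1 through `order` at x
```

All derivatives up to the chosen order come out of a single call, rather than from nesting grad once per order.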
