Let f: R -> R be an infinitely differentiable function. What is the computational complexity of calculating the first n derivatives of f in JAX? A naive application of the chain rule suggests that each multiplication doubles the work, so the nth derivative would require at least 2^n times as many operations. I imagine, though, that clever manipulation of formal series would reduce the number of required calculations and eliminate duplication, especially if the derivatives are jitted in JAX. Is there a difference between the JAX, TensorFlow, and Torch implementations?
https://openreview.net/forum?id=SkxEF3FNPH discusses this topic, but doesn't provide a computational complexity.
CodePudding user response:
What is the computational complexity of calculating the first n derivatives of f in Jax?
There's not much you can say in general about the computational complexity of Nth derivatives. For example, with a function like jnp.sin, the Nth derivative is O(1), cycling between positive and negative sin and cos calls as N grows. For an order-k polynomial, the Nth derivative is identically zero for N > k, so it is trivial to evaluate. Other functions may have complexity that is linear, polynomial, or even exponential in N, depending on the operations they contain.
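As a minimal sketch of the two examples above, the snippet below nests jax.grad to get successive derivatives of jnp.sin and of a cubic. The helper name nth_derivative is hypothetical and is not part of JAX. The sketch only checks the values; it does not measure how the traced program grows with n.

```python
import jax
import jax.numpy as jnp


def nth_derivative(f, n):
    """Hypothetical helper: n-th derivative of a scalar function via nested jax.grad."""
    for _ in range(n):
        f = jax.grad(f)
    return f


x = 0.5

# Derivatives of sin cycle through sin, cos, -sin, -cos, sin, ...
sin_derivs = [float(nth_derivative(jnp.sin, n)(x)) for n in range(5)]


# An order-3 polynomial: every derivative of order N > 3 is zero.
def cubic(t):
    return t ** 3


cubic_third = float(nth_derivative(cubic, 3)(x))   # constant 3! = 6
cubic_fourth = float(nth_derivative(cubic, 4)(x))  # identically zero
```

Nesting jax.grad is the simplest way to experiment, though it says nothing by itself about how much work each extra order adds.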
I imagine though that clever manipulation of formal series would reduce the number of required calculations and eliminate duplications, especially if the derivatives are Jax jitted
You imagine correctly! One implementation of this idea is the jax.experimental.jet module, which is an experimental transform designed for computing higher-order derivatives efficiently and accurately. It doesn't cover all JAX functions, but it may be complete enough to do what you have in mind.
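A rough sketch of what using jet might look like for f = jnp.sin follows. It assumes the jet(fun, primals, series) call signature and assumes that the output series holds unnormalized derivatives along the input path, so that the input series (1, 0, 0) yields f'(x), f''(x), f'''(x). Since the module is experimental, check both assumptions against the JAX version you have installed.

```python
import jax.numpy as jnp
from jax.experimental.jet import jet

x = jnp.float32(0.5)
order = 3  # how many derivatives to propagate in one pass

# Input path x(t) = x + t: first derivative 1, all higher derivatives 0.
series_in = (jnp.ones_like(x),) + (jnp.zeros_like(x),) * (order - 1)

# One call propagates the whole truncated series through the function.
primal_out, series_out = jet(jnp.sin, (x,), (series_in,))

# Under the assumed convention:
#   primal_out    is sin(x)
#   series_out[0] is cos(x)
#   series_out[1] is -sin(x)
#   series_out[2] is -cos(x)
first, second, third = (float(s) for s in series_out)
```

Unlike nesting jax.grad, this computes all requested orders in a single pass, which is the duplication-avoiding idea raised in the question.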