I don't know how to do automatic differentiation of a neural network with two outputs

Time:12-21

I would like to create a physics-informed neural network (PINN) using JAX. I have created a neural network with one input (x) and two outputs (y1, y2). How do I differentiate y1 with respect to x and y2 with respect to x? For a network with one input (x) and one output (y), I know I can get dy/dx with the following, but I do not know how to do it for a network with two outputs.

jax.grad(NN_model)
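To make the single-output case concrete, here is a minimal sketch of a network for which `jax.grad(NN_model)` works. The weights `W1`, `b1`, `W2`, `b2`, the hidden width of 16, and the tanh activation are all hypothetical choices for illustration, not the asker's actual model.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters for a one-hidden-layer network x -> y.
# The hidden width (16) and tanh activation are illustrative assumptions.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
W1 = jax.random.normal(k1, (16,))
b1 = jnp.zeros(16)
W2 = jax.random.normal(k2, (16,))
b2 = 0.0

def NN_model(x):
    """Map a scalar x to a scalar y (parameters captured from the enclosing scope)."""
    h = jnp.tanh(W1 * x + b1)   # hidden activations, shape (16,)
    return jnp.dot(W2, h) + b2  # single scalar output y

# jax.grad works here because NN_model returns a single scalar.
dy_dx = jax.grad(NN_model)(0.5)
```

If `NN_model` instead returned an array of shape `(2,)`, the same `jax.grad` call would raise a `TypeError`, because `jax.grad` only accepts scalar-output functions.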

Answer:

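jax.grad only differentiates functions whose output is a single scalar; applying it to a function that returns an array of shape (2,) raises an error. For vector-valued functions, use jax.jacobian (an alias of jax.jacrev; jax.jacfwd is the forward-mode equivalent), which returns the derivative of every output with respect to every input. For a model mapping a scalar x to (y1, y2), the Jacobian has shape (2,), and its entries are dy1/dx and dy2/dx. Alternatively, wrap each output in its own scalar-valued function, e.g. jax.grad(lambda x: NN_model(x)[0]), and differentiate that with jax.grad as before.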
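A minimal sketch of both approaches, assuming a hypothetical two-output network (the `params` dictionary, `HIDDEN` width, and tanh activation are illustrative, not from the question). It also shows `jax.vmap`, since a PINN typically evaluates these derivatives at many collocation points at once.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters for a one-hidden-layer network x -> (y1, y2).
HIDDEN = 16
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = {
    "W1": jax.random.normal(k1, (HIDDEN,)),
    "b1": jnp.zeros(HIDDEN),
    "W2": jax.random.normal(k2, (2, HIDDEN)),
    "b2": jnp.zeros(2),
}

def NN_model(x):
    """Map a scalar x to the vector [y1, y2]."""
    h = jnp.tanh(params["W1"] * x + params["b1"])  # shape (HIDDEN,)
    return params["W2"] @ h + params["b2"]         # shape (2,)

x = 0.5

# Option 1: Jacobian of the (2,) output w.r.t. scalar x, also of shape (2,).
J = jax.jacobian(NN_model)(x)
dy1_dx, dy2_dx = J[0], J[1]

# Option 2: one scalar-valued function per output, each differentiated with grad.
dy1 = jax.grad(lambda x: NN_model(x)[0])(x)
dy2 = jax.grad(lambda x: NN_model(x)[1])(x)

# Batch over collocation points: row i holds [dy1/dx, dy2/dx] at xs[i].
xs = jnp.linspace(0.0, 1.0, 5)
J_batch = jax.vmap(jax.jacobian(NN_model))(xs)  # shape (5, 2)
```

Both options give the same numbers. With one input and two outputs, `jax.jacfwd` can be swapped in for `jax.jacobian` (forward mode needs one pass per input rather than one per output); the result is the same.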
