I would like to create a physics-informed neural network (PINN) using JAX. I have built a neural network with one input (x) and two outputs (y1, y2). How do I differentiate y1 with respect to x and y2 with respect to x? For a network with one input (x) and one output (y), I know I can compute dy/dx with the following, but I do not know how to do it for a network with two outputs:
jax.grad(NN_model)
Answer:
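jax.grad is designed for differentiating functions that output a single scalar. For a function with several outputs, such as a network returning (y1, y2), use jax.jacobian instead: it returns the derivative of every output with respect to the input. Alternatively, apply jax.grad to a small wrapper function that selects one output at a time.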
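Here is a minimal sketch of both approaches. It assumes the model maps a scalar x to an array [y1, y2]. The tiny MLP, `init_params`, and this particular `NN_model` signature are hypothetical stand-ins for your own network. The sketch also shows `jax.vmap` for evaluating the derivatives at a batch of collocation points, which is the usual setup in a PINN loss.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in network: a one-hidden-layer MLP mapping scalar x -> [y1, y2].
def init_params(key, hidden=16):
    k1, k2 = jax.random.split(key)
    return {
        "W1": jax.random.normal(k1, (hidden, 1)),
        "b1": jnp.zeros(hidden),
        "W2": jax.random.normal(k2, (2, hidden)),
        "b2": jnp.zeros(2),
    }

def NN_model(params, x):
    # x is a scalar; the result has shape (2,) and holds [y1, y2]
    h = jnp.tanh(params["W1"] @ jnp.atleast_1d(x) + params["b1"])
    return params["W2"] @ h + params["b2"]

params = init_params(jax.random.PRNGKey(0))
f = lambda x: NN_model(params, x)  # close over params so f depends on x only

# Option 1: the Jacobian of a scalar -> (2,) function has shape (2,): [dy1/dx, dy2/dx]
dy_dx = jax.jacobian(f)(0.5)

# Option 2: one jax.grad per output, each applied to a scalar-valued wrapper
dy1_dx = jax.grad(lambda x: f(x)[0])(0.5)
dy2_dx = jax.grad(lambda x: f(x)[1])(0.5)

# Derivatives at many collocation points at once: result has shape (5, 2)
xs = jnp.linspace(0.0, 1.0, 5)
dy_dx_batch = jax.vmap(jax.jacobian(f))(xs)
```

`jax.jacobian` is reverse mode. With one input and several outputs, `jax.jacfwd` computes the same result and is often the cheaper choice, since forward-mode cost grows with the number of inputs rather than outputs.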