I was inspecting gradients when I noticed that the gradient of subtracting a tensor's mean along an axis is zero. I find this counter-intuitive, because a gradient of 0 normally means the function is constant. Can anyone explain intuitively why the gradient here is zero?
import tensorflow as tf

o1 = tf.random.normal((3, 3, 3, 3))
with tf.GradientTape() as tape:
    tape.watch(o1)
    o2 = o1 - tf.reduce_mean(o1, 1, keepdims=True)
d = tape.gradient(o2, o1)
tf.print(tf.reduce_max(tf.abs(d)))
This prints 0 for me.
CodePudding user response:
The issue is that tape.gradient, when passed a tensor target, will first compute the sum of that tensor and then compute the gradient of the resulting scalar. That is, tape.gradient only computes gradients of scalar functions.
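You can check this equivalence directly with a small sketch that reuses your setup (persistent=True is only there so the tape can be queried twice):

import tensorflow as tf

o1 = tf.random.normal((3, 3, 3, 3))
with tf.GradientTape(persistent=True) as tape:
    tape.watch(o1)
    o2 = o1 - tf.reduce_mean(o1, 1, keepdims=True)
    s = tf.reduce_sum(o2)  # the scalar that tape.gradient implicitly targets

g_tensor = tape.gradient(o2, o1)  # tensor target: gradient of its sum
g_scalar = tape.gradient(s, o1)   # explicit scalar sum

tf.print(tf.reduce_max(tf.abs(g_tensor - g_scalar)))  # 0: both gradients are identical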
Now, since you subtract the mean of o1 along axis 1, the mean (and thus the sum) of the output is always 0. No matter how o1 changes, you are always subtracting the current mean, so the summed output never moves away from 0, and thus you get a gradient of 0.
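You can verify numerically that the scalar being differentiated is identically zero (a quick check; n below is just the size of the reduced axis):

import tensorflow as tf

n = 3  # size of axis 1, the reduced axis
o1 = tf.random.normal((3, n, 3, 3))
o2 = o1 - tf.reduce_mean(o1, 1, keepdims=True)

# Along axis 1: sum(o2) = sum(o1) - n * mean(o1) = 0, so the total sum is 0 too.
tf.print(tf.reduce_max(tf.abs(tf.reduce_sum(o2, 1))))  # ~0 up to float rounding
tf.print(tf.abs(tf.reduce_sum(o2)))                    # ~0 as well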
Note: GradientTape has a jacobian function which computes a full Jacobian matrix and does not require scalar outputs.
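For instance (a sketch using a smaller tensor than yours so the Jacobian shape stays readable):

import tensorflow as tf

o1 = tf.random.normal((2, 3))  # small tensor; the Jacobian will have shape (2, 3, 2, 3)
with tf.GradientTape() as tape:
    tape.watch(o1)
    o2 = o1 - tf.reduce_mean(o1, 1, keepdims=True)

j = tape.jacobian(o2, o1)           # d o2[i, j] / d o1[k, l]
tf.print(tf.reduce_max(tf.abs(j)))  # non-zero (2/3 here): individual entries do depend on o1

Summing this Jacobian over the output indices gives zero for every input entry, which is exactly why collapsing the output to a scalar sum produces a zero gradient.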