I'm trying to write a custom training loop. Here is sample code of what I'm trying to do: I have two training parameters, and one parameter updates the other. See the code below:
import tensorflow as tf

x1 = tf.Variable(1.0, dtype=float)
x2 = tf.Variable(1.0, dtype=float)
with tf.GradientTape() as tape:
    n = x2 + 4
    x1.assign(n)
    x = x1 + 1
    y = x**2
val = tape.gradient(y, [x1, x2])
for v in val:
    print(v)
and the output is
tf.Tensor(12.0, shape=(), dtype=float32)
None
It seems like GradientTape is not watching the x2 parameter. Both parameters are of type tf.Variable, so GradientTape should watch both of them. I also tried tape.watch(x2), which also didn't work. Am I missing something?
CodePudding user response:
Check the docs regarding a gradient of None. To get the gradient for x1, you have to track x with tape.watch(x):
x1 = tf.Variable(1.0, dtype=float)
x2 = tf.Variable(1.0, dtype=float)
with tf.GradientTape() as tape:
    n = x2 + 4
    x1.assign(n)
    x = x1 + 1
    tape.watch(x)
    y = x**2
dv0, dv1 = tape.gradient(y, [x1, x2])
print(dv0)
print(dv1)
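The watch on x only pays off if you also ask the tape for the gradient with respect to x itself. A minimal sketch of that idea (my own addition, not from the original answer, assuming TF 2.x eager execution):

import tensorflow as tf

x1 = tf.Variable(1.0, dtype=float)
x2 = tf.Variable(1.0, dtype=float)
with tf.GradientTape() as tape:
    n = x2 + 4
    x1.assign(n)
    x = x1 + 1
    tape.watch(x)
    y = x**2
# A non-persistent tape can only be queried once, so request all three
# gradients in a single call.
dy_dx1, dy_dx2, dy_dx = tape.gradient(y, [x1, x2, x])
print(dy_dx1)  # tf.Tensor(12.0, ...): comes from the read of x1 after the assign
print(dy_dx2)  # None: the assign cut the link back to x2
print(dy_dx)   # tf.Tensor(12.0, ...): d(x**2)/dx = 2*x = 12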
However, regarding x2, the output y is not connected to x2 at all, since the x1.assign(n) operation is not tracked by the tape, and that is why the gradient is None. This is consistent with the docs:
State stops gradients. When you read from a stateful object, the tape can only observe the current state, not the history that lead to it.
A tf.Tensor is immutable. You can't change a tensor once it's created. It has a value, but no state. All the operations discussed so far are also stateless: the output of a tf.matmul only depends on its inputs.
A tf.Variable has internal state—its value. When you use the variable, the state is read. It's normal to calculate a gradient with respect to a variable, but the variable's state blocks gradient calculations from going farther back.
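Note that watching is not the problem here: both variables do end up on the tape, but the assign is a state update rather than an operation the tape can differentiate through. A small sketch to check this (my own addition, assuming TF 2.x; tape.watched_variables() lists what the tape recorded):

import tensorflow as tf

x1 = tf.Variable(1.0, dtype=float)
x2 = tf.Variable(1.0, dtype=float)
with tf.GradientTape() as tape:
    n = x2 + 4
    x1.assign(n)
    x = x1 + 1
    y = x**2
# Both variables show up as watched by the tape...
print([v.name for v in tape.watched_variables()])
# ...but the gradient with respect to x2 is still None, because the assign
# only updated x1's state and was not recorded for differentiation.
print(tape.gradient(y, x2))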
If, for example, you do something like this:
x1 = tf.Variable(1.0, dtype=float)
x2 = tf.Variable(1.0, dtype=float)
with tf.GradientTape() as tape:
    n = x2 + 4
    x1 = n
    x = x1 + 1
    tape.watch(x)
    y = x**2
dv0, dv1 = tape.gradient(y, [x1, x2])
It should work.
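As a rough sanity check (my own calculation, not output from the original post): with x1 = n the whole chain stays on the tape, so y = (x2 + 4 + 1)**2 and dy/dx2 = 2*(x2 + 5) = 12. Both dv0 and dv1 should therefore come out as tf.Tensor(12.0, shape=(), dtype=float32) instead of None, because x1 is now an ordinary tensor computed from x2 rather than a stateful variable.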