Say I have a 2-layer neural network and I want to make the second layer non-trainable. So I initialize these variables:
w1 = tf.Variable(tf.random.truncated_normal([28*28, 256], stddev=0.1))
b1 = tf.Variable(tf.zeros([256]))
w2 = tf.Variable(tf.random.truncated_normal([256, 10], stddev=0.1))
b2 = tf.Variable(tf.zeros([10]))
and train them:
for (x, y) in db:
    x = tf.reshape(x, [-1, 28*28])
    with tf.GradientTape() as tape:
        # forward pass: two dense layers with ReLU
        h1 = x @ w1 + tf.broadcast_to(b1, [x.shape[0], 256])
        h1 = tf.nn.relu(h1)
        h2 = h1 @ w2 + tf.broadcast_to(b2, [x.shape[0], 10])
        out = tf.nn.relu(h2)

        # mean squared error against the one-hot labels
        y_onehot = tf.one_hot(y, depth=10)
        loss = tf.square(y_onehot - out)
        loss = tf.reduce_mean(loss)
The problem is, with GradientTape() enclosing the whole forward pass, w2 and b2 are also trainable. How do I make them non-trainable?
Answer:
Your assumption that every variable inside the GradientTape block will be trainable is incorrect. Gradients are only computed for the variables that you pass to tape.gradient as its second argument:

tape.gradient(loss, <your trainable variables here>)

So to freeze the second layer, simply leave w2 and b2 out of that list.
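A minimal sketch of the full training step under that approach, reusing the tf.Variable initialization from the question and assuming a plain SGD update with a hypothetical learning rate lr (not part of the original code):

lr = 0.01  # assumed learning rate for this sketch

for (x, y) in db:
    x = tf.reshape(x, [-1, 28*28])
    with tf.GradientTape() as tape:
        h1 = tf.nn.relu(x @ w1 + b1)    # first (trainable) layer
        out = tf.nn.relu(h1 @ w2 + b2)  # second (frozen) layer
        y_onehot = tf.one_hot(y, depth=10)
        loss = tf.reduce_mean(tf.square(y_onehot - out))

    # Ask the tape only for the first layer's gradients;
    # w2 and b2 are never differentiated here.
    grads = tape.gradient(loss, [w1, b1])
    w1.assign_sub(lr * grads[0])
    b1.assign_sub(lr * grads[1])

Since w2 and b2 never appear in the variable list and no update is ever applied to them, the second layer stays fixed at its initial values throughout training.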