I am trying to write custom training code for a TensorFlow Keras model.
Since we need separate train_step
(which updates the model weights) and test_step
(which only runs inference) methods, I wonder whether we can make a single function that works for both.
My idea is to write something like this:
def _train_or_evaluate(self, inputs, gt1, gt2, is_training=False):
    with tf.GradientTape() as tape1 'if is_training':  # Enable or disable the with statement by condition
        model1_output = self.model1(inputs)
        model1_loss = self.loss_obj(gt1, model1_output)
    inputs2 = self.process_output(model1_output)
    with tf.GradientTape() as tape2 'if is_training':  # Enable or disable the with statement by condition
        model2_output = self.model2(inputs2)
        model2_loss = self.loss_obj2(gt2, model2_output)
    if is_training:
        model1_gradients = tape1.gradient(model1_loss, self.model1.trainable_variables)
        self.optimizer1.apply_gradients(zip(model1_gradients, self.model1.trainable_variables))
        model2_gradients = tape2.gradient(model2_loss, self.model2.trainable_variables)
        self.optimizer2.apply_gradients(zip(model2_gradients, self.model2.trainable_variables))
    return model1_loss, model2_loss

def train_step(self, inputs):
    inputs, (gt1, gt2) = inputs
    return self._train_or_evaluate(inputs, gt1, gt2, True)

def test_step(self, inputs):
    inputs, (gt1, gt2) = inputs
    return self._train_or_evaluate(inputs, gt1, gt2, False)
Is there any way to attach a condition to a with statement that enables or disables it while still executing its block? That is, when the condition is False, the block should run exactly as it would without the with statement.
Such that:
Calling self._train_or_evaluate(inputs, gt1, gt2, True)
should be equivalent to:
def _train(self, inputs, gt1, gt2):
    with tf.GradientTape() as tape1:
        model1_output = self.model1(inputs)
        model1_loss = self.loss_obj(gt1, model1_output)
    inputs2 = self.process_output(model1_output)
    with tf.GradientTape() as tape2:
        model2_output = self.model2(inputs2)
        model2_loss = self.loss_obj2(gt2, model2_output)
    model1_gradients = tape1.gradient(model1_loss, self.model1.trainable_variables)
    self.optimizer1.apply_gradients(zip(model1_gradients, self.model1.trainable_variables))
    model2_gradients = tape2.gradient(model2_loss, self.model2.trainable_variables)
    self.optimizer2.apply_gradients(zip(model2_gradients, self.model2.trainable_variables))
    return model1_loss, model2_loss
And calling self._train_or_evaluate(inputs, gt1, gt2, False)
should be equivalent to:
def _evaluate(self, inputs, gt1, gt2):
    model1_output = self.model1(inputs)
    model1_loss = self.loss_obj(gt1, model1_output)
    inputs2 = self.process_output(model1_output)
    model2_output = self.model2(inputs2)
    model2_loss = self.loss_obj2(gt2, model2_output)
    return model1_loss, model2_loss
Thank you.
CodePudding user response:
You can do this through clever use of contextlib.ExitStack, which separates the with statement from the creation of the managed object, so the object can be created and registered conditionally. For the first conditional with, you'd replace it with:
with contextlib.ExitStack() as stack:
    if is_training:
        tape1 = stack.enter_context(tf.GradientTape())
    model1_output = self.model1(inputs)
    model1_loss = self.loss_obj(gt1, model1_output)
Omit the tape1 =
if you never use the tape afterwards. Here you do need it (for tape1.gradient later), and that's safe because every use of tape1 is guarded by the same is_training condition under which it was defined.
When is_training is falsy, the ExitStack is basically a no-op: it's created and torn down, and since it's managing nothing, no cleanup is performed. When is_training is truthy, the GradientTape instance is created and immediately registered with the ExitStack, so when the with exits, it's cleaned up normally.
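The pattern can be demonstrated without TensorFlow; here is a minimal sketch using a hypothetical recording() context manager as a stand-in for tf.GradientTape(), to show that the block body always runs while the context is only entered conditionally:

```python
import contextlib

events = []

@contextlib.contextmanager
def recording():
    # Stand-in for tf.GradientTape(): logs when it is entered and exited.
    events.append("enter")
    try:
        yield "tape"
    finally:
        events.append("exit")

def run(is_training):
    events.clear()
    with contextlib.ExitStack() as stack:
        if is_training:
            tape = stack.enter_context(recording())
        result = "computed"  # the block body executes either way
    return result, list(events)

print(run(False))  # ('computed', []) -- ExitStack managed nothing
print(run(True))   # ('computed', ['enter', 'exit']) -- entered and cleaned up
```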
CodePudding user response:
As I stated in the comment above, no matter how one customizes the __enter__
method of a class (the one called by the with
statement), it is not possible for the with
block itself to be skipped.
However, if simply nesting an extra if
block is not desirable (or nesting the with
within an if
, for that matter), it is possible to fully emulate a with
block by using try/finally
- and then you can use an if
block instead of a with
. It will still take an extra indentation level, though, since both an if
and a try/finally
block are needed.
import sys

def _train_or_evaluate(self, inputs, gt1, gt2, is_training=False):
    try:
        if is_training:
            tape1 = (cm := tf.GradientTape()).__enter__()
        model1_output = self.model1(inputs)
        model1_loss = self.loss_obj(gt1, model1_output)
    finally:
        if is_training:
            cm.__exit__(*sys.exc_info())
    ...
    # second with block and rest of the function
(Note that I use the walrus operator ( :=
) to keep a reference to the instance on which the with
statement would call the __enter__
and __exit__
methods. Ordinarily, in a with
block one doesn't need to care about this.)
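As a self-contained illustration of the same technique, here is a sketch with a hypothetical Tracker context manager in place of tf.GradientTape, showing that the body runs in both cases and that __enter__/__exit__ are only invoked when the condition holds:

```python
import sys

class Tracker:
    # Stand-in context manager that counts how often it is entered and exited.
    def __init__(self):
        self.entered = 0
        self.exited = 0

    def __enter__(self):
        self.entered += 1
        return self

    def __exit__(self, exc_type, exc, tb):
        self.exited += 1
        return False  # do not suppress exceptions

def run(is_training):
    cm = None
    try:
        if is_training:
            tracker = (cm := Tracker()).__enter__()
        result = "body ran"  # the body executes either way
    finally:
        if is_training:
            cm.__exit__(*sys.exc_info())
    return result, cm

res, cm = run(True)
print(res, cm.entered, cm.exited)  # body ran 1 1
res, cm = run(False)
print(res, cm)                     # body ran None
```

Because __exit__ receives sys.exc_info(), cleanup still happens correctly when the body raises, just as it would under a real with statement.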