Is it possible to give a condition to enable or disable python with statement?


I am trying to make a custom training code for the TensorFlow Keras model.

As we need to separate train_step (which updates the model weights) and test_step (which only makes the inference), I wonder if we can make a single function that can work for both.

My idea is to write something like this.

def _train_or_evaluate(self, inputs, gt1, gt2, is_training=False):
    with tf.GradientTape() as tape1 'if is_training': # Enable or disable the with statement by condition
        model1_output = self.model1(inputs)
        model1_loss = self.loss_obj(gt1, model1_output)

    inputs2 = self.process_output(model1_output)

    with tf.GradientTape() as tape2 'if is_training': # Enable or disable the with statement by condition
        model2_output = self.model2(inputs2)
        model2_loss = self.loss_obj2(gt2, model2_output)

    if is_training:
        model1_gradients = tape1.gradient(model1_loss, self.model1.trainable_variables)
        self.optimizer1.apply_gradients(zip(model1_gradients, self.model1.trainable_variables))

        model2_gradients = tape2.gradient(model2_loss, self.model2.trainable_variables)
        self.optimizer2.apply_gradients(zip(model2_gradients, self.model2.trainable_variables))

    return model1_loss, model2_loss

def train_step(self, inputs):
    inputs, (gt1, gt2) = inputs
    return self._train_or_evaluate(inputs, gt1, gt2, True)

def test_step(self, inputs):
    inputs, (gt1, gt2) = inputs
    return self._train_or_evaluate(inputs, gt1, gt2, False)

Is there any way to enable or disable a with statement by a condition while still executing its block? That is, when the condition is False, the block should run exactly as if the with statement were not there.

Such that:

Calling self._train_or_evaluate(inputs, gt1, gt2, True) should be equivalent to:

def _train(self, inputs, gt1, gt2):
    with tf.GradientTape() as tape1:
        model1_output = self.model1(inputs)
        model1_loss = self.loss_obj(gt1, model1_output)

    inputs2 = self.process_output(model1_output)

    with tf.GradientTape() as tape2:
        model2_output = self.model2(inputs2)
        model2_loss = self.loss_obj2(gt2, model2_output)

    model1_gradients = tape1.gradient(model1_loss, self.model1.trainable_variables)
    self.optimizer1.apply_gradients(zip(model1_gradients, self.model1.trainable_variables))

    model2_gradients = tape2.gradient(model2_loss, self.model2.trainable_variables)
    self.optimizer2.apply_gradients(zip(model2_gradients, self.model2.trainable_variables))

    return model1_loss, model2_loss

And calling self._train_or_evaluate(inputs, gt1, gt2, False) should be equivalent to:

def _evaluate(self, inputs, gt1, gt2):
    model1_output = self.model1(inputs)
    model1_loss = self.loss_obj(gt1, model1_output)

    inputs2 = self.process_output(model1_output)

    model2_output = self.model2(inputs2)
    model2_loss = self.loss_obj2(gt2, model2_output)

    return model1_loss, model2_loss

Thank you.

CodePudding user response:

You can do this through clever use of contextlib.ExitStack, which separates the with statement from the creation of the managed object, so the object can be created and registered conditionally. You'd replace the first conditional with by:

with contextlib.ExitStack() as stack:
    if is_training:
        tape1 = stack.enter_context(tf.GradientTape())
    model1_output = self.model1(inputs)
    model1_loss = self.loss_obj(gt1, model1_output)

Omit the tape1 = if you don't need the tape afterwards. In your code you do use tape1 later, but only inside an if is_training: block, so binding it conditionally here is safe.

When is_training is falsy, the ExitStack is basically a no-op; it's created and torn down, and since it's managing nothing, no cleanup is performed. When is_training is truthy, the GradientTape instance is created and immediately registered with ExitStack, so when the with exits, it's cleaned up normally.
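The pattern can be sketched end to end with a hypothetical stand-in context manager (Tape below) playing the role of tf.GradientTape, so the behavior is visible without TensorFlow installed:

```python
import contextlib

class Tape:
    """Hypothetical stand-in for tf.GradientTape: logs enter/exit events."""
    def __init__(self, log):
        self.log = log
    def __enter__(self):
        self.log.append("enter")
        return self
    def __exit__(self, *exc):
        self.log.append("exit")
        return False  # do not suppress exceptions

def forward(is_training, log):
    with contextlib.ExitStack() as stack:
        if is_training:
            tape1 = stack.enter_context(Tape(log))
        log.append("forward")  # the block body runs in both cases
    return log

print(forward(False, []))  # -> ['forward']
print(forward(True, []))   # -> ['enter', 'forward', 'exit']
```

When is_training is False, ExitStack manages nothing and the body runs bare; when it is True, the tape's __enter__ and __exit__ bracket the body exactly as a plain with would.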

CodePudding user response:

As I stated in the comment above, no matter how one customizes the __enter__ method of a class (the method called by the with statement), it is not possible for the with block itself to be skipped.

However, if simply nesting the with inside an extra if block is not desirable, it is possible to fully emulate a with block using try/finally, guarding the __enter__ and __exit__ calls with an if. It will still take an extra indentation level, though, since both an if and a try/finally block are needed.

import sys

def _train_or_evaluate(self, inputs, gt1, gt2, is_training=False):
    try:
        if is_training:
            tape1 = (cm := tf.GradientTape()).__enter__()
        model1_output = self.model1(inputs)
        model1_loss = self.loss_obj(gt1, model1_output)
    finally:
        if is_training:
            cm.__exit__(*sys.exc_info())
    ...
    # second with block and rest of the function

(Note that I use the walrus operator (:=) to keep a reference to the instance on which the with statement would call the __enter__ and __exit__ methods. Ordinarily, in a with block, one doesn't need to care about this.)
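The same try/finally emulation can be checked in isolation with a hypothetical stand-in context manager (Recorder below, standing in for tf.GradientTape):

```python
import sys

class Recorder:
    """Hypothetical stand-in context manager, logging enter/exit events."""
    def __init__(self, log):
        self.log = log
    def __enter__(self):
        self.log.append("enter")
        return self
    def __exit__(self, exc_type, exc_value, tb):
        self.log.append("exit")
        return False  # do not suppress exceptions

def run(is_training, log):
    try:
        if is_training:
            # walrus keeps a reference to the manager for the finally clause
            tape1 = (cm := Recorder(log)).__enter__()
        log.append("body")  # the body runs in both cases
    finally:
        if is_training:
            cm.__exit__(*sys.exc_info())
    return log

print(run(False, []))  # -> ['body']
print(run(True, []))   # -> ['enter', 'body', 'exit']
```

Note that unlike a real with statement, this emulation ignores the return value of __exit__, so it never suppresses an exception; that is fine for GradientTape, which doesn't suppress either.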
