How to override a context manager correctly?

Time:11-05

This is a follow-up to an earlier question.

I am trying to write a new context manager that occasionally turns off an existing one. The example below defines a context manager, no_autocast, that rebinds autocast so that the overridden one does nothing inside the no_autocast block. But I get the error 'NoneType' object is not callable when calling the redefined one: with autocast("cuda", dtype=torch.float16):. Any idea how to fix it?

import torch
from torch import autocast
print(autocast)
class no_autocast():
    def __init__(self):
        self.autocast = None

    def __enter__(self, *args, **kwargs):
        global autocast
        self.autocast = autocast
        autocast = no_autocast.do_nothing(*args, **kwargs)

    def __exit__(self, exc_type, exc_value, exc_traceback):
        global autocast
        autocast = self.autocast
   
    def do_nothing(*args, **kwargs):
        pass


with no_autocast():
    print("outer 0:", torch.is_autocast_enabled())
    #torch.set_autocast_enabled(False)
    print("outer 1:", torch.is_autocast_enabled())
    with autocast("cuda", dtype=torch.float16):
        def func(a, b): return a @ b
        print("inner ", torch.is_autocast_enabled())

CodePudding user response:

You probably should not use global variables for this purpose. With that out of the way, the problem is this line:

autocast = no_autocast.do_nothing(*args, **kwargs)

You call the do_nothing method, which returns None. So you set autocast to None.
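To see why this leads to the reported error, here is a minimal sketch that does not need PyTorch. The name result is only illustrative; it plays the role of the rebound autocast.

```python
def do_nothing(*args, **kwargs):
    pass

# Calling the function runs its body; a body of just `pass` returns None.
result = do_nothing()
print(result)

# Calling that return value is the same as calling None(...).
try:
    result("cuda")
except TypeError as exc:
    message = str(exc)
    print(message)
```

The TypeError raised here is the same kind of failure the question hits at with autocast("cuda", dtype=torch.float16):.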

You probably meant:

autocast = no_autocast.do_nothing

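This replaces autocast with the do_nothing function itself instead of its return value. However, since autocast is used in a with statement, the replacement must return a context manager when called, not just be a function that does nothing. A simple do-nothing context manager would be: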

import contextlib

@contextlib.contextmanager
def do_nothing(*args, **kwargs):
    yield
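
Putting the pieces together, the following is a rough sketch of how the class from the question might look with that fix applied. To keep it runnable without PyTorch, autocast here is a hypothetical stand-in that only records that it ran, and calls and demo are illustrative helpers that are not part of the original code. It still uses a global, as the question does, even though that is not recommended.

```python
import contextlib

calls = []  # records each time the stand-in autocast actually runs


# Hypothetical stand-in for torch.autocast, so the sketch runs without PyTorch.
@contextlib.contextmanager
def autocast(device_type, dtype=None):
    calls.append(device_type)
    yield


@contextlib.contextmanager
def do_nothing(*args, **kwargs):
    # Accepts and ignores whatever arguments autocast would receive.
    yield


class no_autocast:
    def __enter__(self):
        global autocast
        self._saved = autocast
        autocast = do_nothing  # assign the function itself; do not call it
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        global autocast
        autocast = self._saved  # restore the original on exit
        return False  # do not swallow exceptions


def demo():
    calls.clear()
    with no_autocast():
        with autocast("cuda", dtype="float16"):  # resolves to do_nothing
            pass
    inside = list(calls)
    with autocast("cuda", dtype="float16"):  # the original is back
        pass
    return inside, list(calls)


print(demo())
```

Note that rebinding a global only affects lookups of the name autocast in this module. Code in other modules that imported autocast separately keeps its own reference and is not switched off by this approach.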