Wrap object in custom Python and then add extra logic

Time:04-16

A Python library provides a function create_object that creates an object of type OriginalClass.

I would like to create my own class so that it takes the output of create_object and adds extra logic (on top of what create_object already does). Also, that new custom object should have all the properties of the base object.

So far I've attempted the following:

class MyClass(OriginalClass):
  
  def __init__(self, *args, **kwargs):
    super(MyClass, self).__init__(*args, **kwargs)

This does not accomplish what I have in mind, since the function create_object is never called and the extra logic it handles is never executed.

Also, I do not want to store the output of create_object in an attribute of MyClass (as in self.myobject = create_object()), since I want its behavior to be available directly on an instance of MyClass.

What would be the best way to achieve this in Python? Does it correspond to an existing design pattern?

I am new to Python OOP, so the description may be too vague. Please feel free to ask for more detail on any part that is unclear.

CodePudding user response:

Try this:

class MyClass(OriginalClass):
    def __init__(self, custom_arg, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.init(custom_arg)
    def init(self, custom_arg):
        # add subclass initialization logic here
        self._custom_arg = custom_arg
    def my_class_method(self):
        pass

obj = create_object()
obj.__class__ = MyClass
obj.init(custom_arg)

obj.original_class_method()
obj.my_class_method()

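You can reassign the __class__ attribute of an existing object if you know what you're doing. It only works when the old and new classes have compatible instance layouts (typically the case for a plain Python subclass that does not define __slots__), and __init__ is not run again, which is why the subclass setup lives in a separate init method that is called explicitly.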
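To keep the reassignment out of calling code, it can be hidden behind an alternative constructor. The following is a minimal sketch, not the library's API: OriginalClass and create_object are hypothetical stand-ins for the real library, and the create classmethod is a name introduced here for illustration.

```python
class OriginalClass:
    """Hypothetical stand-in for the library class."""

    def __init__(self, value):
        self.value = value

    def original_class_method(self):
        return self.value


def create_object():
    """Hypothetical stand-in for the library factory."""
    obj = OriginalClass(1)
    obj.configured = True  # represents the extra logic the factory performs
    return obj


class MyClass(OriginalClass):
    @classmethod
    def create(cls, custom_arg):
        # Let the library build and configure the object first.
        obj = create_object()
        # Re-tag the instance as our subclass. This only works when both
        # classes have compatible instance layouts.
        obj.__class__ = cls
        # __init__ is not re-run, so do the subclass setup by hand.
        obj._custom_arg = custom_arg
        return obj

    def my_class_method(self):
        return (self.value, self._custom_arg)


obj = MyClass.create("extra")
```

Here obj is a MyClass instance that still carries everything create_object set up, so both the inherited and the new methods are available on it.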

If I were you, I would consider using the Adapter design pattern instead. It may take longer to write, but it is easier to maintain and understand.
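One possible shape for such an adapter is sketched below. It is only an illustration under stated assumptions: OriginalClass and create_object are hypothetical stand-ins for the library, and MyAdapter and describe are names invented here. The wrapped object is stored privately, and __getattr__ forwards any attribute the adapter does not define itself, so callers never touch the inner object directly.

```python
class OriginalClass:
    """Hypothetical stand-in for the library class."""

    def __init__(self, value):
        self.value = value

    def original_class_method(self):
        return self.value * 2


def create_object():
    """Hypothetical stand-in for the library factory."""
    return OriginalClass(21)


class MyAdapter:
    def __init__(self, custom_arg):
        self._wrapped = create_object()
        self._custom_arg = custom_arg

    def __getattr__(self, name):
        # Only called when normal lookup fails, so attributes defined on
        # the adapter itself take precedence over the wrapped object's.
        if name == "_wrapped":
            # Avoid infinite recursion if _wrapped has not been set yet.
            raise AttributeError(name)
        return getattr(self._wrapped, name)

    def describe(self):
        return f"{self._custom_arg}: {self._wrapped.value}"


adapter = MyAdapter("wrapped")
result = adapter.original_class_method()  # forwarded to the wrapped object
```

Two trade-offs of this approach: isinstance(adapter, OriginalClass) is False, and special methods such as __len__ or __iter__ are not forwarded by __getattr__, so they would have to be defined on the adapter explicitly if needed.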

CodePudding user response:

Looking at the original code, I would have implemented the create_object functions as class methods.

class SqueezeNet(nn.Module):
    ...

    @classmethod
    def squeezenet1_0(cls, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "SqueezeNet":
        return cls._squeezenet('1_0', pretrained, progress, **kwargs)

    @classmethod
    def squeezenet1_1(cls, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "SqueezeNet":
        return cls._squeezenet('1_1', pretrained, progress, **kwargs)

    @classmethod
    def _squeezenet(cls, version: str, pretrained: bool, progress: bool, **kwargs: Any) -> "SqueezeNet":
        # cls is whichever class the method was called on, so a subclass
        # gets an instance of itself here.
        model = cls(version, **kwargs)
        if pretrained:
            arch = 'squeezenet' + version
            state_dict = load_state_dict_from_url(model_urls[arch],
                                                  progress=progress)
            model.load_state_dict(state_dict)
        return model

So what does the class method do? It just instantiates the object as normal, but then calls a particular method on it before returning it. As such, there's nothing to do in your subclass. Calling MySqueezeNetSubclass._squeezenet would instantiate your subclass, not SqueezeNet. If you need to customize anything else, you can override _squeezenet in your own class, using super()._squeezenet to do the parent creation first before modifying the result.

class MySubclass(SqueezeNet):
    @classmethod
    def _squeezenet(cls, *args, **kwargs):
        model = super()._squeezenet(*args, **kwargs)
        # Do MySubclass-specific things to model here
        return model
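The dispatch through cls can be checked with a small toy version that needs no PyTorch. This is a sketch with hypothetical names (Base, Sub, v1_0, _build), and the loaded flag merely stands in for loading pretrained weights.

```python
class Base:
    def __init__(self, version):
        self.version = version
        self.loaded = False

    @classmethod
    def v1_0(cls, pretrained=False):
        # Public factory: delegates to the shared builder on cls.
        return cls._build("1_0", pretrained)

    @classmethod
    def _build(cls, version, pretrained):
        model = cls(version)  # instantiates the class the call started on
        if pretrained:
            model.loaded = True  # stands in for loading weights
        return model


class Sub(Base):
    @classmethod
    def _build(cls, version, pretrained):
        # Let the parent create and configure the instance first.
        model = super()._build(version, pretrained)
        model.tag = "custom"  # subclass-specific step
        return model


sub_model = Sub.v1_0(pretrained=True)
base_model = Base.v1_0()
```

Calling the inherited factory on Sub yields a Sub instance with the parent's setup already applied, while Base.v1_0() is unaffected by the override.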

In the actual library, however, _squeezenet isn't a class method; it's a regular module-level function that instantiates SqueezeNet directly. There's not much you can do except patch it at runtime, which has to happen before anything calls it. For example,

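import torchvision.models.squeezenet
# Assumption: model_urls and load_state_dict_from_url are module-level names
# in the torchvision version this answer is based on; other versions may differ.
from torchvision.models.squeezenet import load_state_dict_from_url, model_urls

def _new_squeezenet(version, pretrained, progress, **kwargs):
    # MySqueezeNetSubClass is your own subclass of SqueezeNet.
    model = MySqueezeNetSubClass(version, **kwargs)
    # Maybe more changes specific to your code here. Specifically,
    # you might want to provide your own URL rather than one from
    # model_urls, etc.
    if pretrained:
        arch = 'squeezenet' + version
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model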


torchvision.models.squeezenet._squeezenet = _new_squeezenet

The lesson here is that not everything is designed to be easily subclassed.
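If a permanent, global patch is too invasive, the replacement can be limited to a block of code with unittest.mock.patch.object from the standard library, which restores the original function on exit. The sketch below uses a hypothetical stand-in module named lib (built in place so the example is self-contained) rather than torchvision itself; Net, _make, net1_0, MyNet, and _my_make are all invented names.

```python
import types
from unittest import mock

# Hypothetical stand-in for a third-party module with a private builder.
lib = types.ModuleType("lib")
exec(
    """
class Net:
    def __init__(self, version):
        self.version = version

def _make(version):
    return Net(version)

def net1_0():
    # _make is looked up in the module globals at call time,
    # which is what makes runtime patching possible.
    return _make("1_0")
""",
    lib.__dict__,
)


class MyNet(lib.Net):
    def describe(self):
        return f"MyNet {self.version}"


def _my_make(version):
    # Replacement builder that instantiates the subclass instead.
    return MyNet(version)


# The patch is active only inside the with block.
with mock.patch.object(lib, "_make", _my_make):
    patched = lib.net1_0()

restored = lib.net1_0()
```

Inside the block the public function returns the subclass; afterwards the module behaves as before, which keeps the patch from leaking into unrelated code.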
