Cannot manage to call child methods in a strategy pattern

Time:02-04

I am stuck on a strategy pattern implementation.

I would like the child classes' methods implementing the strategy pattern to be called, but it seems that only the abstract class's method is called.

from abc import abstractmethod

class TranslationStrategy:

    @classmethod
    @abstractmethod
    def translate_in_french(cls, text: str) -> str:
        pass

    @classmethod
    @abstractmethod
    def translate_in_spanish(cls, text: str) -> str:
        pass

    FRENCH = translate_in_french
    SPANISH = translate_in_spanish


class Translator(TranslationStrategy):

    @abstractmethod
    def __init__(self, strategy = TranslationStrategy.FRENCH):
        self.strategy = strategy

    def get_translation(self):
        print(self.strategy("random string"))

class LiteraryTranslator(Translator):

    def __init__(self, strategy = TranslationStrategy.FRENCH):
        super().__init__(strategy)

    def translate_in_french(self, text: str) -> str:
        return "french_literary_translation"

    def translate_in_spanish(self, text: str) -> str:
        return "spanish_literary_translation"


class TechnicalTranslator(Translator):

    def __init__(self, strategy=TranslationStrategy.FRENCH):
        super().__init__(strategy)

    def translate_in_french(self, text: str) -> str:
        return "french_technical_translation"

    def translate_in_spanish(self, text: str) -> str:
        return "spanish_technical_translation"

translator = TechnicalTranslator(TranslationStrategy.FRENCH)
translator.get_translation() # prints None, I expect "french_technical_translation"

Am I misusing the strategy pattern here?

Answer:

I am not familiar with the strategy pattern, but to make your code work, you could use something like the following:

from abc import ABC, abstractmethod
from enum import Enum


class TranslationStrategy(str, Enum):
    FRENCH = 'french'
    SPANISH = 'spanish'
    
    @classmethod
    def _missing_(cls, value):
        if isinstance(value, str):
            try:
                return cls._member_map_[value.upper()]
            except KeyError:
                pass
        return super()._missing_(value)


class Translator(ABC):
    def __init__(self, strategy=TranslationStrategy.FRENCH):
        self._strategy = TranslationStrategy(strategy)
        # .value is the plain 'french'/'spanish' string on every Python version
        self.strategy = getattr(self, f'translate_in_{self._strategy.value}')

    @abstractmethod
    def translate_in_french(self, text: str) -> str:
        pass

    @abstractmethod
    def translate_in_spanish(self, text: str) -> str:
        pass
    
    def get_translation(self, text: str = 'random string'):
        print(self.strategy(text))


class LiteraryTranslator(Translator):
    def __init__(self, strategy=TranslationStrategy.FRENCH):
        super().__init__(strategy)

    def translate_in_french(self, text: str) -> str:
        return "french_literary_translation"

    def translate_in_spanish(self, text: str) -> str:
        return "spanish_literary_translation"


class TechnicalTranslator(Translator):
    def __init__(self, strategy=TranslationStrategy.FRENCH):
        super().__init__(strategy)

    def translate_in_french(self, text: str) -> str:
        return "french_technical_translation"

    def translate_in_spanish(self, text: str) -> str:
        return "spanish_technical_translation"

Note that the __init__ method is NOT marked as an abstract method. Don't mark __init__ as abstract, nor any other method whose base implementation you rely on (here, the subclasses call super().__init__(...) and expect it to do real work).
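A small sketch of what goes wrong with an abstract __init__ on an ABC, using hypothetical class names: any subclass that simply inherits __init__ still counts as abstract and cannot be instantiated, so every subclass is forced to override __init__ even when it only forwards to super().

```python
from abc import ABC, abstractmethod


class WithAbstractInit(ABC):
    @abstractmethod
    def __init__(self, strategy="french"):
        self.strategy = strategy


class Child(WithAbstractInit):
    # No __init__ here: this subclass relies on the parent's implementation.
    pass


try:
    Child()
except TypeError as exc:
    # ABC refuses to instantiate a class that still has abstract methods.
    print(exc)


class ChildWithInit(WithAbstractInit):
    def __init__(self, strategy="french"):
        super().__init__(strategy)  # the "abstract" body still runs


print(ChildWithInit("spanish").strategy)  # spanish
```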

The self._strategy = TranslationStrategy(strategy) line ensures that the given strategy is a member of the enum: it normalizes the input (case-insensitively, through the _missing_ hook) and rejects invalid values:

>>> TranslationStrategy('French')
<TranslationStrategy.FRENCH: 'french'>

>>> TranslationStrategy('french')
<TranslationStrategy.FRENCH: 'french'>

>>> TranslationStrategy('FRENCH')
<TranslationStrategy.FRENCH: 'french'>

>>> TranslationStrategy(TranslationStrategy.FRENCH)
<TranslationStrategy.FRENCH: 'french'>

>>> TranslationStrategy('foo')
Traceback (most recent call last):
...
ValueError: 'foo' is not a valid TranslationStrategy
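The getattr call builds a method name from the enum member, so it matters how that member turns into a string. Python 3.11 changed Enum.__format__ so that an f-string of a str-mixin member such as this one yields the class-qualified name (TranslationStrategy.FRENCH) rather than the value, which would produce a name like translate_in_TranslationStrategy.FRENCH and make getattr fail. A minimal sketch showing that the member's .value is the version-independent choice:

```python
from enum import Enum


class TranslationStrategy(str, Enum):
    FRENCH = 'french'
    SPANISH = 'spanish'


member = TranslationStrategy.FRENCH

# .value is always the plain string 'french', on every Python version.
method_name = f'translate_in_{member.value}'
print(method_name)  # translate_in_french

# Formatting the member itself is version-dependent: on Python 3.11+ this
# may print 'translate_in_TranslationStrategy.FRENCH' instead.
print(f'translate_in_{member}')
```

On Python 3.11+, deriving from enum.StrEnum instead of (str, Enum) is another option, since StrEnum members format as their value.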

To call a subclass's method, you have to look it up once the subclass is known, i.e. on the instance. The self.strategy = getattr(self, ...) line in __init__ does exactly that: it stores a reference to the translate_in_french or translate_in_spanish method of the current object, which resolves to the one defined in the class you instantiated. Your approach did not work because FRENCH = translate_in_french is evaluated once, inside the body of TranslationStrategy, so TranslationStrategy.FRENCH always refers to the abstract TranslationStrategy.translate_in_french (and SPANISH to TranslationStrategy.translate_in_spanish). That base method's body is just pass, so it returns None, and the subclass overrides are never consulted.
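A stripped-down reproduction of the original problem, using hypothetical Base and Technical classes in place of the question's hierarchy: the alias captured in the base class body keeps pointing at the base function, while a lookup by name on the instance finds the override.

```python
from abc import abstractmethod


class Base:
    @classmethod
    @abstractmethod
    def translate_in_french(cls, text: str) -> str:
        pass  # the base body implicitly returns None

    # Evaluated once, in Base's own class body: FRENCH holds Base's
    # classmethod object and is never re-resolved on subclasses.
    FRENCH = translate_in_french


class Technical(Base):
    def __init__(self, strategy=Base.FRENCH):
        # Base.FRENCH is Base's function bound to the class Base.
        self.strategy = strategy

    def translate_in_french(self, text: str) -> str:
        return "french_technical_translation"


t = Technical()
print(t.strategy("random string"))  # None: Base's body ran

# Looking the method up by name on the instance finds the override.
late_bound = getattr(t, "translate_in_french")
print(late_bound("random string"))  # french_technical_translation
```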

Technically, the __init__ implementations in LiteraryTranslator and TechnicalTranslator are not strictly necessary here, since they do nothing except call super().__init__(...) with the same arguments; a subclass that defines no __init__ inherits the parent's automatically.
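A trimmed sketch of the classes above with the subclass __init__ methods removed, to show that the inherited Translator.__init__ is enough. For brevity it omits the _missing_ hook, so only the exact values 'french' and 'spanish' (or enum members) are accepted, and it uses .value when building the method name.

```python
from abc import ABC, abstractmethod
from enum import Enum


class TranslationStrategy(str, Enum):
    FRENCH = 'french'
    SPANISH = 'spanish'


class Translator(ABC):
    def __init__(self, strategy=TranslationStrategy.FRENCH):
        self._strategy = TranslationStrategy(strategy)
        # Resolved on the instance, so the subclass implementation is found.
        self.strategy = getattr(self, f'translate_in_{self._strategy.value}')

    @abstractmethod
    def translate_in_french(self, text: str) -> str: ...

    @abstractmethod
    def translate_in_spanish(self, text: str) -> str: ...

    def get_translation(self, text: str = 'random string'):
        print(self.strategy(text))


class TechnicalTranslator(Translator):
    # No __init__: Translator.__init__ is inherited unchanged.
    def translate_in_french(self, text: str) -> str:
        return "french_technical_translation"

    def translate_in_spanish(self, text: str) -> str:
        return "spanish_technical_translation"


TechnicalTranslator().get_translation()           # french_technical_translation
TechnicalTranslator('spanish').get_translation()  # spanish_technical_translation
```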

Lastly, stacking @classmethod on top of @abstractmethod produces an abstract class method, not an abstract instance (normal) method. Since your subclasses implement these as normal methods, the @classmethod decorator has been omitted.
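A small comparison with hypothetical class names: both forms register the method as abstract, but a class method binds to the class (no instance, so no per-instance state such as self.strategy), whereas a normal method binds to the instance it was looked up on.

```python
from abc import ABC, abstractmethod


class ClassLevel(ABC):
    @classmethod
    @abstractmethod
    def translate_in_french(cls, text: str) -> str: ...


class InstanceLevel(ABC):
    @abstractmethod
    def translate_in_french(self, text: str) -> str: ...


class ClassImpl(ClassLevel):
    @classmethod
    def translate_in_french(cls, text: str) -> str:
        return "french"


class InstanceImpl(InstanceLevel):
    def translate_in_french(self, text: str) -> str:
        return "french"


# A class method binds to the class itself.
print(ClassImpl.translate_in_french.__self__ is ClassImpl)  # True

# A normal method binds to the instance.
obj = InstanceImpl()
print(obj.translate_in_french.__self__ is obj)  # True
```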

Answer:

The strategy pattern is much easier to implement in Python, because you can simply pass functions as arguments and return new functions from a function. There is no need for all the class boilerplate.

from typing import Callable

TranslationStrategy = Callable[[str], str]
Translator = Callable[[str], str]

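def literary_french(text: str) -> str:
    return "french_literary_translation"  # placeholder implementation


def literary_spanish(text: str) -> str:
    return "spanish_literary_translation"  # placeholder implementation


def technical_french(text: str) -> str:
    return "french_technical_translation"  # placeholder implementation


def technical_spanish(text: str) -> str:
    return "spanish_technical_translation"  # placeholder implementation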


def make_translator(strategy: TranslationStrategy) -> Translator:

    # Yes, this is simple enough we could just write
    #
    #     return strategy
    # 
    def translate(text: str) -> str:
        return strategy(text)
    return translate


translator = make_translator(technical_french)
print(translator("..."))
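If the strategy has to be chosen at runtime (for example from user input), the functions can be kept in a lookup table. A minimal sketch, where STRATEGIES and pick_strategy are hypothetical names and the function bodies are the same placeholders as above:

```python
from typing import Callable, Dict, Tuple

TranslationStrategy = Callable[[str], str]


def technical_french(text: str) -> str:
    return "french_technical_translation"


def technical_spanish(text: str) -> str:
    return "spanish_technical_translation"


# Hypothetical registry keyed by (style, language); add entries as needed.
STRATEGIES: Dict[Tuple[str, str], TranslationStrategy] = {
    ("technical", "french"): technical_french,
    ("technical", "spanish"): technical_spanish,
}


def pick_strategy(style: str, language: str) -> TranslationStrategy:
    # Lower-case the key so "Technical"/"FRENCH" also match; raises
    # KeyError for combinations that are not registered.
    return STRATEGIES[(style.lower(), language.lower())]


translator = pick_strategy("technical", "Spanish")
print(translator("..."))  # spanish_technical_translation
```

Because the strategies are plain functions, adding a new one is just another dictionary entry; no subclass is needed.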