__eq__ order enforcement in Python

Time:12-10

This is a slightly long question, so that the background is explained sufficiently...

Assume there is a built-in class A that we cannot edit directly:

class A:
    def __init__(self, a=None):
        self.a = a
    def __eq__(self, other):
        return self.a == other.a

It is expected to be compared like this:

a1, a2 = A(1), A(2)
a1 == a2  # False

For some reason, the team introduced a wrapper on top of it. (To keep the example simple, the code below doesn't actually wrap A.)

class WrapperA:
    def __init__(self, a=None):
        self.pa = a
    def __eq__(self, other):
        return self.pa == other.pa

Again, it is expected to be compared like this:

wa1, wa2 = WrapperA(1), WrapperA(2)
wa1 == wa2  # False

Although code is expected to use either A or WrapperA, some code bases mix both, so the following comparisons fail:

a, wa = A(), WrapperA()
wa == a  # AttributeError
a == wa  # AttributeError

A known solution is to modify __eq__:

For wa == a:

class WrapperA:
    def __init__(self, a=None):
        self.pa = a
    def __eq__(self, other):
        if isinstance(other, A):
            return self.pa == other.a
        return self.pa == other.pa

For a == wa:

class A:
    def __init__(self, a=None):
        self.a = a
    def __eq__(self, other):
        if isinstance(other, WrapperA):
            return self.a == other.pa
        return self.a == other.a

Modifying WrapperA is acceptable. Since A is built in and cannot be edited directly, there are two options:

  1. Use setattr to extend A to support WrapperA:
     setattr(A, '__eq__', eq_that_supports_WrapperA)
  2. Require developers to only write wa == a (and then not care about a == wa).
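A minimal sketch of option 1, assuming A is an ordinary Python class (setattr cannot patch true built-in types such as int; it raises TypeError). The name eq_that_supports_WrapperA comes from the question; its body here is hypothetical. It hands WrapperA operands to WrapperA.__eq__ and otherwise defers to the saved original A.__eq__:

```python
class A:
    # Stand-in for the class that cannot be edited directly
    def __init__(self, a=None):
        self.a = a

    def __eq__(self, other):
        return self.a == other.a


class WrapperA:
    def __init__(self, a=None):
        self.pa = a

    def __eq__(self, other):
        if isinstance(other, A):
            return self.pa == other.a
        return self.pa == other.pa


# Keep a reference to the original method so the patch can delegate to it
_original_eq = A.__eq__


def eq_that_supports_WrapperA(self, other):
    # Hypothetical patch: let WrapperA.__eq__ handle wrapper operands
    if isinstance(other, WrapperA):
        return other == self
    return _original_eq(self, other)


setattr(A, '__eq__', eq_that_supports_WrapperA)

print(A(1) == WrapperA(1))  # True
print(WrapperA(1) == A(1))  # True
```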

The 1st option is obviously ugly because it duplicates the implementation, and the 2nd gives developers an unnecessary "surprise". So my question is: is there an elegant way to make Python internally turn any use of a == wa into wa == a?

Answer:

Similar to Ron Serruya's answer:

This uses __getattr__ instead of __getattribute__. __getattr__ is only called when normal attribute lookup (__getattribute__) raises an AttributeError, or when it is called explicitly (ref). So if the wrapper does not implement __eq__ and equality should only be decided by the underlying data (stored in objects of class A), a working example is:

class A:
    def __init__(self, internal_data=None):
        self._internal_data = internal_data

    def __eq__(self, other):
        return self._internal_data == other._internal_data

class WrapperA:
    def __init__(self, a_object: A):
        self._a = a_object

    def __getattr__(self, attribute):
        # Only reached when normal lookup fails; refusing '_a' prevents
        # infinite recursion if _a has not been set yet
        if attribute == '_a':
            raise AttributeError(attribute)
        return getattr(self._a, attribute)

a1 = A(internal_data=1)
a2 = A(internal_data=2)

wa1 = WrapperA(a1)
wa2 = WrapperA(a2)

print(
    a1 == a1,
    a1 == a2,
    wa1 == wa1,
    a1 == wa1,
    a2 == wa2,
    wa1 == a1)

Output:

True False True True True True
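One caveat about this approach: for operators like ==, Python looks the special method up on the type, not the instance, so __getattr__ is never consulted for __eq__. WrapperA simply inherits object.__eq__, and wa1 == a1 works only because object.__eq__ returns NotImplemented for a different object, after which Python tries the reflected A.__eq__. A sketch (the classes are repeated so it runs on its own) showing that two distinct wrappers around equal data therefore compare unequal:

```python
class A:
    def __init__(self, internal_data=None):
        self._internal_data = internal_data

    def __eq__(self, other):
        return self._internal_data == other._internal_data


class WrapperA:
    def __init__(self, a_object):
        self._a = a_object

    def __getattr__(self, attribute):
        # Only reached when normal lookup fails; refuse '_a' to avoid recursion
        if attribute == '_a':
            raise AttributeError(attribute)
        return getattr(self._a, attribute)


# WrapperA defines no __eq__ of its own, so == uses the inherited object.__eq__
print('__eq__' in vars(WrapperA))            # False

# object.__eq__ gives NotImplemented, so the reflected A.__eq__ runs, and its
# other._internal_data access falls through to __getattr__
print(WrapperA(A(1)) == A(1))                # True

# Both sides return NotImplemented here, so == falls back to identity
print(WrapperA(A(1)) == WrapperA(A(1)))      # False
```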

Answer:

Quoting the comment from MisterMiyagi under the question:

Note that == is generally expected to work across all types. A.__eq__ requiring other to be an A is actually a bug that should be fixed. It should at the very least return NotImplemented when it cannot make a decision.

This is important, not just a question of style. In fact, according to the documentation:

When a binary (or in-place) method returns NotImplemented the interpreter will try the reflected operation on the other type.
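This also means a comparison with an unrelated type stops raising: if both sides return NotImplemented, == falls back to an identity check (is), and != to is not. A minimal sketch, where Other is a hypothetical unrelated class:

```python
class A:
    def __init__(self, a=None):
        self.a = a

    def __eq__(self, other):
        if isinstance(other, A):
            return self.a == other.a
        # Let Python try the reflected operation instead of raising
        return NotImplemented


class Other:
    # Hypothetical unrelated class with no __eq__ of its own
    pass


print(A(1) == A(1))       # True
print(A(1) == Other())    # False: both sides give NotImplemented, identity is used
print(A(1) != Other())    # True
```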

Thus, if you just apply MisterMiyagi's comment and fix the logic of __eq__, you'll see that your code already works:

class A:
    def __init__(self, a=None):
        self.a = a

    def __eq__(self, other):
        if isinstance(other, A):
            return self.a == other.a
        return NotImplemented


class WrapperA:
    def __init__(self, a=None):
        self.pa = a

    def __eq__(self, other):
        if isinstance(other, A):
            return self.pa == other.a
        elif isinstance(other, WrapperA):
            return self.pa == other.pa
        return NotImplemented

# Trying it
a = A(5)
wrap_a = WrapperA(5)

print(a == wrap_a)
print(wrap_a == a)

wrap_a.pa = 7
print(a == wrap_a)
print(wrap_a == a)
print(f'{wrap_a.pa=}')

Yields:

True
True
False
False
wrap_a.pa=7

Under the hood, a == wrap_a calls A.__eq__ first, which returns NotImplemented. Python then automatically tries the reflected operation, WrapperA.__eq__(wrap_a, a), instead.
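The dispatch order can be made visible by recording each call. The calls list is a hypothetical tracing aid added for illustration; otherwise the classes match the answer's code:

```python
calls = []  # records which __eq__ ran, in order


class A:
    def __init__(self, a=None):
        self.a = a

    def __eq__(self, other):
        calls.append('A.__eq__')
        if isinstance(other, A):
            return self.a == other.a
        return NotImplemented


class WrapperA:
    def __init__(self, a=None):
        self.pa = a

    def __eq__(self, other):
        calls.append('WrapperA.__eq__')
        if isinstance(other, A):
            return self.pa == other.a
        elif isinstance(other, WrapperA):
            return self.pa == other.pa
        return NotImplemented


result = A(5) == WrapperA(5)
print(result)  # True
print(calls)   # ['A.__eq__', 'WrapperA.__eq__']
```

One exception to this left-first order: if the right operand's type is a subclass of the left operand's type, Python tries the right operand's reflected method first.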

Answer:

I don't really like this whole thing, since I think that wrapping a built-in and using different attribute names will lead to unexpected behavior, but anyway, this will work for you:

import inspect


class A:
    def __init__(self, a=None):
        self.a = a

    def __eq__(self, other):
        return self.a == other.a


class WrapperA:
    def __init__(self, a=None):
        self.pa = a

    def __eq__(self, other):
        if isinstance(other, A):
            return self.pa == other.a
        return self.pa == other.pa

    def __getattribute__(self, item):
        # Figure out who tried to get the attribute
        # If the item requested was 'a', check if A's __eq__ method called us,
        # in that case return pa instead
        caller = inspect.stack()[1]
        if item == 'a' and getattr(caller, 'function') == '__eq__' and isinstance(caller.frame.f_locals.get('self'), A):
            return super(WrapperA, self).__getattribute__('pa')
        return super(WrapperA, self).__getattribute__(item)

a = A(5)
wrap_a = WrapperA(5)

print(a == wrap_a)
print(wrap_a == a)

wrap_a.pa = 7
print(a == wrap_a)
print(wrap_a == a)
print(f'{wrap_a.pa=}')

Output:

True
True
False
False
wrap_a.pa=7
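Note that inspect.stack() builds the entire call stack, including source context, on every attribute access of a WrapperA, including self.pa. A lighter variant of the same idea, assuming an interpreter where inspect.currentframe() returns a frame (it may return None elsewhere), only inspects the immediate caller and only when 'a' is requested:

```python
import inspect


class A:
    def __init__(self, a=None):
        self.a = a

    def __eq__(self, other):
        return self.a == other.a


class WrapperA:
    def __init__(self, a=None):
        self.pa = a

    def __eq__(self, other):
        if isinstance(other, A):
            return self.pa == other.a
        return self.pa == other.pa

    def __getattribute__(self, item):
        # Only pay for frame inspection when the redirected name is requested
        if item == 'a':
            frame = inspect.currentframe()
            # currentframe() may return None without stack frame support
            caller = frame.f_back if frame is not None else None
            try:
                if (caller is not None
                        and caller.f_code.co_name == '__eq__'
                        and isinstance(caller.f_locals.get('self'), A)):
                    return object.__getattribute__(self, 'pa')
            finally:
                # Drop frame references to avoid reference cycles
                del frame, caller
        return object.__getattribute__(self, item)


a = A(5)
wrap_a = WrapperA(5)
print(a == wrap_a, wrap_a == a)  # True True
```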