How does numpy scalar multiplication work?

Time:11-21

I'm just curious about how numpy implements its scalar multiplication. In my understanding, x * y calls the __mul__ method of x, x.__mul__(y).

How would this work with this code:

import numpy as np

x = 6
y = np.array([1,2])
print(x * y)

Wouldn't the code call x.__mul__(y)? And since x is not a numpy variable, how does numpy have any control over how this behaves?

Answer:

When Python encounters x * y, it will indeed first call x.__mul__(y) (1). Any such method may return NotImplemented if it doesn't know how to handle the operand it received, and that is exactly what x.__mul__(y) does here:

>>> import numpy as np
>>> x = 6
>>> y = np.array([1, 2])
>>> x.__mul__(y)
NotImplemented

Then, as the next step, Python calls the reflected-operand method __rmul__ of the right operand, i.e. y.__rmul__(x):

>>> y.__rmul__(x)
array([ 6, 12])

So in this second step, it is up to the np.ndarray class to determine the outcome of this operation (2).
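As a small check on the NumPy side (a sketch assuming NumPy is installed): NumPy documents its ndarray arithmetic operators as equivalent to the corresponding ufuncs, so the reflected call y.__rmul__(x) should give the same result as np.multiply(x, y), which broadcasts the scalar over every element. The names reflected and via_ufunc are only illustrative.

```python
import numpy as np

x = 6
y = np.array([1, 2])

# The reflected call Python makes after int.__mul__ returns NotImplemented.
reflected = y.__rmul__(x)

# ndarray arithmetic is documented as equivalent to the matching ufunc,
# here np.multiply, which broadcasts the scalar x across y.
via_ufunc = np.multiply(x, y)

print(reflected)                             # [ 6 12]
print(np.array_equal(reflected, via_ufunc))  # True
```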

This can also be visualized with the help of a custom test class:

>>> class Test:
...     def __rmul__(self, other):
...         print('__rmul__', self, other)
... 
>>> x = 1
>>> y = Test()
>>> x.__mul__(y)
NotImplemented
>>> x * y
__rmul__ <__main__.Test object at 0x7f46c6acb668> 1
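The lookup order described above can be summarized as a rough pure-Python model. The helper emulate_mul is hypothetical (it is not part of Python or NumPy) and deliberately simplified: CPython performs this dispatch in C and handles details, such as sequence repetition slots, that this sketch ignores.

```python
def emulate_mul(a, b):
    """Simplified, hypothetical model of how Python evaluates ``a * b``."""
    a_type, b_type = type(a), type(b)
    # Special methods are looked up on the type, not on the instance.
    mul = getattr(a_type, "__mul__", None)
    # The reflected method is only considered when the types differ.
    rmul = getattr(b_type, "__rmul__", None) if b_type is not a_type else None

    # Footnote (1): a right operand whose type subclasses the left type
    # and overrides __rmul__ gets the first try.
    if (rmul is not None and issubclass(b_type, a_type)
            and rmul is not getattr(a_type, "__rmul__", None)):
        result = rmul(b, a)
        if result is not NotImplemented:
            return result
        rmul = None  # already tried; do not call it a second time

    # Step 1: the left operand's __mul__.
    if mul is not None:
        result = mul(a, b)
        if result is not NotImplemented:
            return result

    # Step 2: the right operand's reflected __rmul__.
    if rmul is not None:
        result = rmul(b, a)
        if result is not NotImplemented:
            return result

    # Footnote (2): neither side handled the operation.
    raise TypeError(
        f"unsupported operand type(s) for *: "
        f"'{a_type.__name__}' and '{b_type.__name__}'"
    )


class Test:
    def __rmul__(self, other):
        return f"Test.__rmul__ got {other!r}"


print(emulate_mul(6, 7))       # 42, handled by int.__mul__
print(emulate_mul(2, "ab"))    # abab, int gives up and str.__rmul__ handles it
print(emulate_mul(1, Test()))  # Test.__rmul__ got 1
```

With NumPy, emulate_mul(6, np.array([1, 2])) follows the same path as the answer: int.__mul__ returns NotImplemented, so ndarray.__rmul__ produces the result.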

(1) Unless type(y) is a subclass of type(x) that overrides __rmul__, in which case y.__rmul__(x) is tried before x.__mul__(y).

(2) If both methods, x.__mul__ and y.__rmul__, return NotImplemented (or do not exist), a TypeError is raised.
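Both footnotes can be observed with ordinary classes. Base, Derived, Refuses and product_or_error are made-up names for illustration; the sketch relies only on Python's documented operator rules.

```python
class Base:
    def __mul__(self, other):
        return "Base.__mul__"

    def __rmul__(self, other):
        return "Base.__rmul__"


class Derived(Base):
    # Subclass that overrides the reflected method: footnote (1) applies.
    def __rmul__(self, other):
        return "Derived.__rmul__"


class Refuses:
    # Reflected method that also gives up: footnote (2) applies.
    def __rmul__(self, other):
        return NotImplemented


def product_or_error(a, b):
    """Return a * b, or the TypeError message if neither side handles it."""
    try:
        return a * b
    except TypeError as exc:
        return f"TypeError: {exc}"


print(Base() * Derived())             # Derived.__rmul__ runs before Base.__mul__
print(Base() * Base())                # same type: only Base.__mul__ is tried
print(product_or_error(2, Refuses())) # TypeError: unsupported operand type(s) for *: 'int' and 'Refuses'
```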
