Is there a reasonable way to use the np.linalg.multi_dot() function with Nx2x2 arrays, the way functools.reduce(np.matmul, Nx2x2_arrays) works? Please see the example below.
import numpy as np
from functools import reduce
m1 = np.array(range(16)).reshape(4, 2, 2)
m2 = m1.copy()
m3 = m1.copy()
reduce(np.matmul, (m1, m2, m3))
Result (a 4x2x2 array):
array([[[    6,    11],
        [   22,    39]],

       [[  514,   615],
        [  738,   883]],

       [[ 2942,  3267],
        [ 3630,  4031]],

       [[ 8826,  9503],
        [10234, 11019]]])
As you can see, np.matmul treats a 4x2x2 3-D array as a 1-D stack of 2x2 matrices. Can I do the same with np.linalg.multi_dot() instead of reduce(np.matmul), and if so, will it be any faster?
Answer:
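np.linalg.multi_dot() is documented to accept only 2-D arrays (the first and last arguments may also be 1-D), so it cannot be applied directly to stacks of matrices like your 4x2x2 arrays. Its purpose is to choose the order of the dot products that leads to the fewest scalar multiplications overall. Since all your matrices are square and of the same size (2x2), every order costs the same number of multiplications, so there is nothing for it to optimize.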
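To make the cost argument concrete, here is a small sketch that counts scalar multiplications for the two possible orders of a three-matrix chain, using the usual rule that multiplying a (p x q) matrix by a (q x r) matrix costs p*q*r multiplications. The helper `chain_cost` is hypothetical and not part of NumPy.

```python
def chain_cost(dims, order):
    """Count scalar multiplications for a chain of three matrices.

    dims  -- (p0, p1, p2, p3): matrix i has shape (dims[i], dims[i + 1])
    order -- "left" for (A @ B) @ C, "right" for A @ (B @ C)
    """
    p0, p1, p2, p3 = dims
    if order == "left":
        # A @ B costs p0*p1*p2 and yields (p0 x p2); times C costs p0*p2*p3.
        return p0 * p1 * p2 + p0 * p2 * p3
    if order == "right":
        # B @ C costs p1*p2*p3 and yields (p1 x p3); A times it costs p0*p1*p3.
        return p1 * p2 * p3 + p0 * p1 * p3
    raise ValueError(f"unknown order: {order!r}")


# Three 2x2 matrices: both orders cost the same.
square = (2, 2, 2, 2)
print(chain_cost(square, "left"), chain_cost(square, "right"))  # 16 16

# A non-square chain (10x100, 100x5, 5x50): here the order matters a lot.
rect = (10, 100, 5, 50)
print(chain_cost(rect, "left"), chain_cost(rect, "right"))  # 7500 75000
```

This kind of gap between orders is what multi_dot exploits; with equal-sized square matrices it simply does not exist.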
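Internally, np.linalg.multi_dot() does not run any special compiled routine: it is written in Python, works out the multiplication order, and then simply calls np.dot(). It would therefore bring no speed-up here, and you can get the same result with np.matmul, which (unlike np.dot()) treats 3-D arrays as stacks of matrices: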
functools.reduce(np.matmul, (m1, m2, m3))
or simply
m1 @ m2 @ m3
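A quick sketch to check these claims, assuming a reasonably recent NumPy: passing the three 3-D arrays to np.linalg.multi_dot() raises np.linalg.LinAlgError, while calling it once per 2x2 slice works but needs a Python-level loop and gives the same numbers as @.

```python
from functools import reduce

import numpy as np

m1 = np.arange(16).reshape(4, 2, 2)
m2 = m1.copy()
m3 = m1.copy()

# multi_dot requires 2-D operands, so a stack of matrices is rejected.
try:
    np.linalg.multi_dot([m1, m2, m3])
except np.linalg.LinAlgError as exc:
    print("multi_dot failed:", exc)

# Workaround: call multi_dot once per 2x2 slice (a Python-level loop).
looped = np.array([np.linalg.multi_dot(mats) for mats in zip(m1, m2, m3)])

# np.matmul / @ handles the leading (stack) axis in a single vectorized call.
stacked = m1 @ m2 @ m3

assert np.array_equal(looped, stacked)
assert np.array_equal(reduce(np.matmul, (m1, m2, m3)), stacked)
print(stacked.shape)  # (4, 2, 2)
```

The per-slice loop is shown only to illustrate the equivalence; for many small matrices the vectorized @ is the natural choice.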
Answer:
You could also use np.einsum():
np.einsum('ijk,ikl,ilm->ijm',m1,m2,m3)
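A sketch that checks the einsum expression against @ and gives a rough timing of the three approaches with the standard timeit module. Timings depend on the machine, NumPy version and array sizes, so no numbers are shown here; for 2x2 matrices the differences are largely per-call overhead.

```python
import timeit
from functools import reduce

import numpy as np

m1 = np.arange(16).reshape(4, 2, 2)
m2 = m1.copy()
m3 = m1.copy()

# The batch axis i is kept; k and l are summed over, i.e.
# out[i, j, m] = sum over k, l of m1[i, j, k] * m2[i, k, l] * m3[i, l, m].
via_einsum = np.einsum('ijk,ikl,ilm->ijm', m1, m2, m3)
assert np.array_equal(via_einsum, m1 @ m2 @ m3)

# Rough timing only; results vary from machine to machine.
candidates = {
    "reduce(np.matmul)": lambda: reduce(np.matmul, (m1, m2, m3)),
    "m1 @ m2 @ m3": lambda: m1 @ m2 @ m3,
    "np.einsum": lambda: np.einsum('ijk,ikl,ilm->ijm', m1, m2, m3),
}
for name, fn in candidates.items():
    seconds = timeit.timeit(fn, number=1_000)
    print(f"{name:>18}: {seconds:.4f} s for 1,000 calls")
```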