Custom non-linear matrix multiplication in NumPy

Time:06-23

Suppose I have two matrices U and W:

import numpy as np

U = np.arange(6*2).reshape((6,2))
W = np.arange(5*2).reshape((5,2))

For a standard linear multiplication, I could do:

U @ W.T
array([[  1,   3,   5,   7,   9],
       [  3,  13,  23,  33,  43],
       [  5,  23,  41,  59,  77],
       [  7,  33,  59,  85, 111],
       [  9,  43,  77, 111, 145],
       [ 11,  53,  95, 137, 179]])

But I could also (technically) define a linear multiplication function that works column by column, and sum the results in a loop:

def mult(U, W, i):
  return U[:, [i]] @ W.T[[i],:]

sum([mult(U, W, i) for i in range(2)]) #1
array([[  1,   3,   5,   7,   9],
       [  3,  13,  23,  33,  43],
       [  5,  23,  41,  59,  77],
       [  7,  33,  59,  85, 111],
       [  9,  43,  77, 111, 145],
       [ 11,  53,  95, 137, 179]])
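Each mult(U, W, i) term is the outer product of column i of U with column i of W. As a sketch, the same decomposition can be spelled with np.outer (a swapped-in spelling, not the question's own code):

```python
import numpy as np

U = np.arange(6*2).reshape((6, 2))
W = np.arange(5*2).reshape((5, 2))

# U @ W.T equals the sum, over the shared axis k, of the outer products
# of column k of U with column k of W (i.e. row k of W.T).
outer_sum = sum(np.outer(U[:, k], W[:, k]) for k in range(U.shape[1]))

print(np.array_equal(outer_sum, U @ W.T))  # True
```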

Now suppose mult() isn't linear anymore but some custom non-linear function, for example:

def mult(U, W, i):
  return (U[:, [i]] @ W.T[[i],:]) * np.cos(U[:, [i]] @ W.T[[i],:])

sum([mult(U, W, i) for i in range(2)]) #2

You can verify this isn't identical to (U @ W.T) * np.cos(U @ W.T). But I wonder whether there is a more compact way of writing #2, just as there is a more compact way of writing #1 when mult() is linear. Efficiency would be nice, but I'm not dealing with huge matrices.
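One way to run that check, as a sketch reusing the question's U, W and non-linear mult (only the intermediate variable p is new):

```python
import numpy as np

U = np.arange(6*2).reshape((6, 2))
W = np.arange(5*2).reshape((5, 2))

def mult(U, W, i):
    # Non-linear per-column term from the question: p * cos(p) for the outer product p
    p = U[:, [i]] @ W.T[[i], :]
    return p * np.cos(p)

loop_version = sum(mult(U, W, i) for i in range(2))   # #2: f applied per column, then summed
P = U @ W.T
after_sum = P * np.cos(P)                             # f applied after summing

print(np.allclose(loop_version, after_sum))  # False
```

Row 0 happens to agree (U[0, 0] is 0, so one product is 0 and contributes nothing either way), but from row 1 on the two results differ.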

Answer:

@, like np.dot, is matrix multiplication, involving what we often call a sum of products. This is a basic linear algebra operation, and np.matmul uses highly efficient compiled libraries to do it (where possible).

Your sum([mult(...)]) is doing the same thing: it takes the column-by-row products and sums them. The compiled code probably uses more efficient methods that work well in iterative C or Fortran.
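To make the sum of products explicit, np.einsum (not used elsewhere in this answer) names the indices: for each output entry (i, j), multiply U[i, k] by W[j, k] and sum over k. The explicit loops below are for illustration only and are far slower than @:

```python
import numpy as np

U = np.arange(6*2).reshape((6, 2))
W = np.arange(5*2).reshape((5, 2))

# 'ik,jk->ij': k appears in both inputs but not in the output, so it is summed over
via_einsum = np.einsum('ik,jk->ij', U, W)

# The same sum of products written as explicit Python loops
n, m = U.shape[0], W.shape[0]
via_loops = np.zeros((n, m), dtype=U.dtype)
for i in range(n):
    for j in range(m):
        via_loops[i, j] = sum(U[i, k] * W[j, k] for k in range(U.shape[1]))

print(np.array_equal(via_einsum, U @ W.T), np.array_equal(via_loops, U @ W.T))  # True True
```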

Your mult function could use broadcasted element-wise multiplication. For one i:

In [43]: i=1;U[:, [i]] @ W.T[[i],:]     # (6,1) @ (1,5) => (6,5)
Out[43]: 
array([[ 1,  3,  5,  7,  9],
       [ 3,  9, 15, 21, 27],
       [ 5, 15, 25, 35, 45],
       [ 7, 21, 35, 49, 63],
       [ 9, 27, 45, 63, 81],
       [11, 33, 55, 77, 99]])

In [44]: i=1;U[:, [i]] * W.T[[i],:]
Out[44]: 
array([[ 1,  3,  5,  7,  9],
       [ 3,  9, 15, 21, 27],
       [ 5, 15, 25, 35, 45],
       [ 7, 21, 35, 49, 63],
       [ 9, 27, 45, 63, 81],
       [11, 33, 55, 77, 99]])
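The two outputs agree because a (6,1) @ (1,5) product has only one term per entry, so nothing is summed. A short sketch that checks this for every column i:

```python
import numpy as np

U = np.arange(6*2).reshape((6, 2))
W = np.arange(5*2).reshape((5, 2))

for i in range(U.shape[1]):
    col = U[:, [i]]       # shape (6, 1)
    row = W.T[[i], :]     # shape (1, 5)
    # A matmul with inner dimension 1 equals the broadcasted
    # element-wise product (6,1) * (1,5) -> (6,5)
    assert np.array_equal(col @ row, col * row)
```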

And without the list comprehension, the same thing can be written as a 3-D broadcasted product summed over the last axis:

In [46]: (U[:,None,:]*W[None,:,:]).shape
Out[46]: (6, 5, 2)

In [47]: (U[:,None,:]*W[None,:,:]).sum(axis=2)
Out[47]: 
array([[  1,   3,   5,   7,   9],
       [  3,  13,  23,  33,  43],
       [  5,  23,  41,  59,  77],
       [  7,  33,  59,  85, 111],
       [  9,  43,  77, 111, 145],
       [ 11,  53,  95, 137, 179]])
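This outer/sum pattern generalizes to any element-wise function applied to the products before summing. A minimal sketch, assuming f is a vectorized (NumPy element-wise) function; the name outer_apply_sum is hypothetical:

```python
import numpy as np

def outer_apply_sum(U, W, f=lambda p: p):
    """Hypothetical helper: out[i, j] = sum_k f(U[i, k] * W[j, k]).

    With the default identity f this reduces to U @ W.T.
    """
    # (n,1,K) * (1,m,K) broadcasts to (n,m,K): all pairwise products for each k
    prods = U[:, None, :] * W[None, :, :]
    return f(prods).sum(axis=2)

U = np.arange(6*2).reshape((6, 2))
W = np.arange(5*2).reshape((5, 2))

linear = outer_apply_sum(U, W)                               # same as U @ W.T
nonlinear = outer_apply_sum(U, W, lambda p: p * np.cos(p))   # the question's #2
```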

As for your version with np.cos:

In [48]: def mult(U, W, i):
    ...:   return (U[:, [i]] @ W.T[[i],:]) * np.cos(U[:, [i]] @ W.T[[i],:])
    ...: sum([mult(U, W, i) for i in range(2)]) #2
Out[48]: 
array([[ 5.40302306e-01, -2.96997749e+00,  1.41831093e+00,
         5.27731578e+00, -8.20017236e+00],
       [-2.96997749e+00, -1.08147468e+01, -1.25593190e+01,
        -1.37606696e+00, -2.32102995e+01],
       [ 1.41831093e+00, -1.25593190e+01,  9.45751861e+00,
        -2.14489310e+01,  5.03346370e+01],
       [ 5.27731578e+00, -1.37606696e+00, -2.14489310e+01,
         1.01223418e+01,  3.13845563e+01],
       [-8.20017236e+00, -2.32102995e+01,  5.03346370e+01,
         3.13845563e+01,  8.79904273e+01],
       [ 4.86826779e-02,  7.72350858e+00, -2.54605509e+01,
        -5.95298563e+01, -4.88871235e+00]])

I can use the same outer/sum format:

In [49]: (U[:,None,:]*W[None,:,:]*np.cos(U[:,None,:]*W[None,:,:])).sum(axis=2)
Out[49]: 
array([[ 5.40302306e-01, -2.96997749e+00,  1.41831093e+00,
         5.27731578e+00, -8.20017236e+00],
       [-2.96997749e+00, -1.08147468e+01, -1.25593190e+01,
        -1.37606696e+00, -2.32102995e+01],
       [ 1.41831093e+00, -1.25593190e+01,  9.45751861e+00,
        -2.14489310e+01,  5.03346370e+01],
       [ 5.27731578e+00, -1.37606696e+00, -2.14489310e+01,
         1.01223418e+01,  3.13845563e+01],
       [-8.20017236e+00, -2.32102995e+01,  5.03346370e+01,
         3.13845563e+01,  8.79904273e+01],
       [ 4.86826779e-02,  7.72350858e+00, -2.54605509e+01,
        -5.95298563e+01, -4.88871235e+00]])

And since the outer product is used twice, we can use a temporary variable:

In [51]: temp=U[:,None,:]*W[None,:,:]
    ...: (temp*np.cos(temp)).sum(axis=2)
Out[51]: 
array([[ 5.40302306e-01, -2.96997749e+00,  1.41831093e+00,
         5.27731578e+00, -8.20017236e+00],
       [-2.96997749e+00, -1.08147468e+01, -1.25593190e+01,
        -1.37606696e+00, -2.32102995e+01],
       [ 1.41831093e+00, -1.25593190e+01,  9.45751861e+00,
        -2.14489310e+01,  5.03346370e+01],
       [ 5.27731578e+00, -1.37606696e+00, -2.14489310e+01,
         1.01223418e+01,  3.13845563e+01],
       [-8.20017236e+00, -2.32102995e+01,  5.03346370e+01,
         3.13845563e+01,  8.79904273e+01],
       [ 4.86826779e-02,  7.72350858e+00, -2.54605509e+01,
        -5.95298563e+01, -4.88871235e+00]])

The fact that you can't simply interchange the multiplication and sum steps is a matter of basic algebra.

To get

a1*b1 + a2*b2

from

(a1 + a2)*(b1 + b2) => a1*b1 + a1*b2 + a2*b1 + a2*b2

the a1*b2 + a2*b1 terms have to sum to zero, as with the squared magnitude of a complex number (the number times its conjugate):

In [53]: (1+4j)*(1-4j)
Out[53]: (17+0j)    # (1+16)
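Expanding both products term by term makes the cross terms visible: for ordinary real numbers they generally do not cancel, while for a complex number times its conjugate they do (a1*b2 = -4j and a2*b1 = 4j). The helper name expand is hypothetical:

```python
def expand(a1, a2, b1, b2):
    """Hypothetical helper: return (product of sums, matched products, cross terms)."""
    product_of_sums = (a1 + a2) * (b1 + b2)
    matched = a1 * b1 + a2 * b2
    cross = a1 * b2 + a2 * b1
    return product_of_sums, matched, cross

# Real example: cross terms 1*4 + 2*3 = 10, so 21 != 11
print(expand(1, 2, 3, 4))   # (21, 11, 10)

# (1+4j)*(1-4j): take a = (1, 4j), b = (1, -4j); the cross terms cancel
pos, matched, cross = expand(1, 4j, 1, -4j)
```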

The sum of products cannot, in general, be converted to a product of sums.
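The same reasoning covers the non-linear case in the question: applying f to each product and then summing (#2) is generally not the same as summing first and then applying f. A quick numeric check for entry [1, 1], whose products are 2*2 and 3*3 (U[1] = (2, 3), W[1] = (2, 3)):

```python
import numpy as np

def f(p):
    # The question's non-linear function
    return p * np.cos(p)

products = np.array([2 * 2, 3 * 3])   # U[1, k] * W[1, k] for k = 0, 1

sum_of_f = f(products).sum()   # what #2 computes for entry [1, 1]
f_of_sum = f(products.sum())   # what (U @ W.T) * np.cos(U @ W.T) computes there

print(sum_of_f, f_of_sum)
```

The first value matches the -1.08147468e+01 entry in the outputs above; the second is 13*cos(13), which is positive.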
