Fastest way to multiply and sum 4D array with 2D array in python?

Time:12-18

Here's my problem. I have two arrays A and B with complex entries, of shapes (n, n, m, m) and (n, n) respectively.

This is the operation I use to compute the (m, m) matrix C:

C = np.sum(B[:,:,None,None]*A, axis=(0,1))

Computing this once takes about 6-8 seconds, and since I have to compute many such Cs, the total time adds up. Is there a faster way to do this? (I'm using JAX NumPy on a multi-core CPU; plain NumPy is even slower.)

For reference, n = 77 and m = 512. I'm working on a cluster, so I can parallelize, but the arrays are large enough that memory use becomes a problem.
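To put the memory concern in numbers, here is the size of A for n = 77 and m = 512. This assumes dense storage at NumPy's default complex128 (16 bytes per entry) and, for comparison, complex64 (8 bytes per entry), which is what JAX uses by default unless 64-bit mode is enabled. The helper `nbytes_for` is hypothetical, written just for this calculation. Note also that the broadcast product `B[:,:,None,None]*A` allocates a temporary of the same shape as A before the sum runs.

```python
import numpy as np

n, m = 77, 512  # sizes from the question

def nbytes_for(shape, dtype):
    # Bytes needed to store a dense array of the given shape and dtype
    return int(np.prod(shape, dtype=np.int64)) * np.dtype(dtype).itemsize

a_c128 = nbytes_for((n, n, m, m), np.complex128)  # 16 bytes per entry
a_c64 = nbytes_for((n, n, m, m), np.complex64)    # 8 bytes per entry

print(f"A as complex128: {a_c128 / 2**30:.1f} GiB")  # 23.2 GiB
print(f"A as complex64:  {a_c64 / 2**30:.1f} GiB")   # 11.6 GiB
```

So a single copy of A is already tens of GiB at double precision, and the original expression roughly doubles the peak usage because of the temporary.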

Answer:

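It looks like you want einsum, which computes the weighted sum over the first two axes directly instead of first building the full (n, n, m, m) product array: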

C = np.einsum('ijkl,ij->kl', A, B)
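The einsum call is a single tensor contraction over the (i, j) axes, so it can also be written with `np.tensordot`, or as one vector-matrix product after flattening (i, j) into a single axis. The sketch below (an alternative not benchmarked in this answer; the function names are hypothetical) checks that all three agree with the original broadcast-and-sum on small random complex arrays. Because the last two forms go through a BLAS-backed matrix product they may be faster than einsum's default path, but that is worth measuring on your own arrays. `jax.numpy` also provides `einsum` and `tensordot` with the same calling conventions.

```python
import numpy as np

def weighted_sum_einsum(A, B):
    # C[k, l] = sum over i, j of B[i, j] * A[i, j, k, l]
    return np.einsum('ijkl,ij->kl', A, B)

def weighted_sum_tensordot(A, B):
    # Same contraction: axes 0,1 of B are summed against axes 0,1 of A
    return np.tensordot(B, A, axes=([0, 1], [0, 1]))

def weighted_sum_matvec(A, B):
    # Flatten (i, j) into one axis and do a vector-matrix product.
    # For a C-contiguous A this reshape is a view, so A is not copied.
    n1, n2, m1, m2 = A.shape
    return (B.reshape(n1 * n2) @ A.reshape(n1 * n2, m1 * m2)).reshape(m1, m2)

rng = np.random.default_rng(0)
n, m = 6, 8  # small sizes for a quick check; the question uses n=77, m=512
A = rng.standard_normal((n, n, m, m)) + 1j * rng.standard_normal((n, n, m, m))
B = rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))

# Original broadcast-and-sum version from the question, used as the reference
reference = np.sum(B[:, :, None, None] * A, axis=(0, 1))
for f in (weighted_sum_einsum, weighted_sum_tensordot, weighted_sum_matvec):
    np.testing.assert_allclose(f(A, B), reference)
```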

Here is a comparison with plain NumPy on a Colab CPU, using smaller real-valued arrays (n = 50, m = 500):

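import numpy as np

# x plays the role of A, shape (n, n, m, m); y plays the role of B, shape (n, n)
x = np.random.rand(50, 50, 500, 500)
y = np.random.rand(50, 50)

def f1(x, y):
  # Original approach: broadcast y, multiply elementwise (this builds a full
  # (n, n, m, m) temporary), then sum over the first two axes
  return np.sum(y[:,:,None,None]*x, axis=(0,1))

def f2(x, y):
  # einsum: contract the first two axes of x against y
  return np.einsum('ijkl,ij->kl', x, y)

# Both versions give the same result
np.testing.assert_allclose(f1(x, y), f2(x, y))

# %timeit is an IPython/Colab magic, not plain Python
%timeit f1(x, y)
# 1 loop, best of 5: 1.52 s per loop
%timeit f2(x, y)
# 1 loop, best of 5: 620 ms per loop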
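If many of the Cs share the same A and only B changes (an assumption; the question does not say which array varies), the B matrices can be stacked and all Cs computed with one matrix-matrix product, so A is read from memory once rather than once per C. A minimal sketch with the hypothetical helper `weighted_sums_shared_A`; `np.einsum('ijkl,pij->pkl', A, Bs)` expresses the same contraction.

```python
import numpy as np

def weighted_sums_shared_A(A, Bs):
    # Compute C_p[k, l] = sum over i, j of Bs[p, i, j] * A[i, j, k, l] for every p.
    # Assumes all Cs use the same A; Bs stacks the B matrices, shape (p, n, n).
    n1, n2, m1, m2 = A.shape
    p = Bs.shape[0]
    # (p, n*n) @ (n*n, m*m) -> (p, m*m): a single matrix-matrix product
    Cs = Bs.reshape(p, n1 * n2) @ A.reshape(n1 * n2, m1 * m2)
    return Cs.reshape(p, m1, m2)

rng = np.random.default_rng(1)
n, m, p = 5, 7, 3  # small illustrative sizes
A = rng.standard_normal((n, n, m, m)) + 1j * rng.standard_normal((n, n, m, m))
Bs = rng.standard_normal((p, n, n)) + 1j * rng.standard_normal((p, n, n))

Cs = weighted_sums_shared_A(A, Bs)
# Each slice matches the single-B einsum from the answer
for q in range(p):
    np.testing.assert_allclose(Cs[q], np.einsum('ijkl,ij->kl', A, Bs[q]))
```

The trade-off is memory for the stacked Bs and Cs, which is small next to A when p is modest.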