I asked a question here with the details: https://math.stackexchange.com/questions/4381785/possibly-speed-up-matrix-multiplications
In short, I am trying to create a P x N matrix X with typical element X_{pi} = \sum_{j \neq i} \sum_{k \neq i} w_{pj} A_{jk} Y_{pk}, where w is P x N, A is N x N and Y is P x N. See the link above for a rendered version of that formula.
Here is an MWE. I'd like to know how to correct the code (the calculations seem correct, just incomplete, see below) and, more importantly, how to speed it up as much as possible:
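For reference, here is a direct and deliberately unoptimized transcription of the formula. It is a sketch that assumes the index order implied by the array shapes, X[p, i] = sum over j != i and k != i of w[p, j] * A[j, k] * Y[p, k], with 0-based indices. The function name naive_X is hypothetical.

```python
import numpy as np

def naive_X(w, A, Y):
    """Hypothetical reference: X[p, i] = sum_{j != i, k != i} w[p, j] * A[j, k] * Y[p, k]."""
    P, N = w.shape
    X = np.zeros((P, N))
    for p in range(P):
        for i in range(N):              # every column, including the last one
            total = 0.0
            for j in range(N):
                if j == i:
                    continue            # exclude j == i
                for k in range(N):
                    if k == i:
                        continue        # exclude k == i
                    total += w[p, j] * A[j, k] * Y[p, k]
            X[p, i] = total
    return X

w = np.array([[2, 1], [3, 7]])
A = np.array([[2, 1], [9, -1]])
Y = np.array([[6, 2], [11, 8]])
print(naive_X(w, A, Y))
```

For the example arrays this gives -2 and -56 in the first column and 24 and 66 in the second. It is O(P N^3) and only useful as a correctness check.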
```python
import numpy as np

w = np.array([[2,1],[3,7]])
A = np.array([[2,1],[9,-1]])
Y = np.array([[6,2],[11,8]])
N = w.shape[1]
P = w.shape[0]
X = np.zeros((P, N))
for p in range(P):
    for i in range(N-1):
        for j in range(N-1):
            X[p,i] = np.delete(w,i,1)[i,p]*np.delete(np.delete(A,i,0),i,1)[i,j]*np.delete(Y.T,i,0)[j,p]
```
The output looks like:
```
array([[ -2.,   0.],
       [-56.,   0.]])
```
Taking i = 1 and p = 1 (1-indexed), the value can be understood from the formula above: \sum_{j,k \neq 1} w_{1j} A_{jk} Y_{1k} = w_{12} A_{22} Y_{12} = 1 \cdot (-1) \cdot 2 = -2, as in the output.
For i = 1 and p = 2 the element should be \sum_{j,k \neq 1} w_{2j} A_{jk} Y_{2k} = w_{22} A_{22} Y_{22} = 7 \cdot (-1) \cdot 8 = -56, again as in the output.
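Applying the same formula to the last column (i = 2), where only j = k = 1 survives, shows what the missing column should contain. These values are computed by hand from the example arrays, assuming the index order w_{pj}, Y_{pk} used in the worked examples above:

$$
\begin{aligned}
i = 2,\ p = 1:&\quad \sum_{j,k \neq 2} w_{1j} A_{jk} Y_{1k} = w_{11} A_{11} Y_{11} = 2 \cdot 2 \cdot 6 = 24,\\
i = 2,\ p = 2:&\quad \sum_{j,k \neq 2} w_{2j} A_{jk} Y_{2k} = w_{21} A_{11} Y_{21} = 3 \cdot 2 \cdot 11 = 66.
\end{aligned}
$$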
However, I am not getting the correct values for the final column of X, because my loops only run to N-1 rather than N: when I use N, I get an IndexError (index out of bounds). More importantly, here N = P = 2, but my real N and P are large, and the code as is takes a very long time to run. Any suggestions would be greatly appreciated.
Answer:
Since the delete calls depend only on i, I factored them out of the inner loops and reordered the loops. I also corrected the w1 index order.
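To make the factoring concrete: for a fixed i, the three np.delete calls build reduced copies of the inputs that do not depend on p or j. A small illustration with the question's arrays and i = 0 (the choice of i is just for illustration):

```python
import numpy as np

w = np.array([[2, 1], [3, 7]])
A = np.array([[2, 1], [9, -1]])
Y = np.array([[6, 2], [11, 8]])
i = 0

w1 = np.delete(w, i, axis=1)                         # drop column i of w -> (P, N-1)
a1 = np.delete(np.delete(A, i, axis=0), i, axis=1)   # drop row i and column i of A -> (N-1, N-1)
y1 = np.delete(Y.T, i, axis=0)                       # drop row i of Y.T (column i of Y) -> (N-1, P)

print(w1.tolist())  # [[1], [7]]
print(a1.tolist())  # [[-1]]
print(y1.tolist())  # [[2, 8]]
```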
```
In [274]: w = np.array([[2,1],[3,7]])
     ...: A = np.array([[2,1],[9,-1]])
     ...: Y = np.array([[6,2],[11,8]])
     ...: N=w.shape[1]
     ...: P=w.shape[0]
     ...: X = np.zeros((P, N))
     ...: for i in range(N-1):
     ...:     print('i',i)
     ...:     w1 = np.delete(w,i,1)
     ...:     a1 = np.delete(np.delete(A,i,0),i,1)
     ...:     y1 = np.delete(Y.T,i,0)
     ...:     print(w1.shape, a1.shape, y1.shape)
     ...:     print(w1@a1@y1)
     ...:     print(np.einsum('ij,jk,ki->i',w1,a1,y1))
     ...:     for p in range(P):
     ...:         for j in range(N-1):
     ...:             X[p,i] = w1[p,i]*a1*y1[j,p]
     ...:
i 0
(2, 1) (1, 1) (1, 2)
[[ -2  -8]
 [-14 -56]]
[ -2 -56]

In [275]: X
Out[275]:
array([[ -2.,   0.],
       [-56.,   0.]])
```
Your [-2, -56] values are the diagonal of w1@a1@y1, or equivalently the einsum result. The 0s in the second column are left over from the initial np.zeros, because i only runs over range(1).
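Following the same idea but letting i run over range(N) fills in the last column as well. This is a sketch that assumes column i of X is the diagonal of w1 @ a1 @ y1 for that i; the einsum subscripts 'pj,jk,kp->p' compute just that diagonal without forming the full P x P product. The function name X_by_column is hypothetical.

```python
import numpy as np

def X_by_column(w, A, Y):
    """Hypothetical helper: column i of X is diag(w1 @ a1 @ y1), with index i removed."""
    P, N = w.shape
    X = np.zeros((P, N))
    for i in range(N):                                       # all N columns, not N-1
        w1 = np.delete(w, i, axis=1)                         # (P, N-1)
        a1 = np.delete(np.delete(A, i, axis=0), i, axis=1)   # (N-1, N-1)
        y1 = np.delete(Y.T, i, axis=0)                       # (N-1, P)
        # Diagonal of w1 @ a1 @ y1, computed without the P x P intermediate
        X[:, i] = np.einsum('pj,jk,kp->p', w1, a1, y1)
    return X

w = np.array([[2, 1], [3, 7]])
A = np.array([[2, 1], [9, -1]])
Y = np.array([[6, 2], [11, 8]])
print(X_by_column(w, A, Y))
```

This still loops over the N columns in Python and makes three np.delete copies per column, so it mainly fixes correctness rather than speed.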
This should be faster because the np.delete calls are no longer repeated unnecessarily inside the inner loops. np.delete is still relatively expensive (it returns a new copy each time), but I haven't tried to work out exactly what you are computing.
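A different technique, not the one used in this answer, avoids np.delete entirely by inclusion-exclusion: start from the unrestricted sum over all j, k, subtract the terms with j = i and those with k = i, and add back the j = k = i term that was subtracted twice. Assuming X_{pi} = \sum_{j \neq i} \sum_{k \neq i} w_{pj} A_{jk} Y_{pk}, this gives X_{pi} = S_p - w_{pi} (Y A^T)_{pi} - Y_{pi} (w A)_{pi} + w_{pi} A_{ii} Y_{pi}, where S_p = \sum_{j,k} w_{pj} A_{jk} Y_{pk}. A sketch (X_vectorized is a hypothetical name), checked against a direct loop on random (2,3) and (3,3) inputs:

```python
import numpy as np

def X_vectorized(w, A, Y):
    """Hypothetical helper: X[p, i] = sum_{j != i, k != i} w[p, j] * A[j, k] * Y[p, k],
    computed by inclusion-exclusion with two matrix products (no np.delete)."""
    wA = w @ A                       # (P, N): wA[p, k] = sum_j w[p, j] * A[j, k]
    YAt = Y @ A.T                    # (P, N): YAt[p, j] = sum_k A[j, k] * Y[p, k]
    S = (wA * Y).sum(axis=1)         # (P,): unrestricted sum over all j, k
    return (S[:, None]               # start from the full sum
            - w * YAt                # remove terms with j == i
            - Y * wA                 # remove terms with k == i
            + w * Y * np.diag(A))    # add back j == k == i (removed twice)

# Compare against a direct loop on non-square random inputs (P = 2, N = 3)
rng = np.random.default_rng(0)
P, N = 2, 3
w = rng.standard_normal((P, N))
A = rng.standard_normal((N, N))
Y = rng.standard_normal((P, N))

X_loop = np.zeros((P, N))
for p in range(P):
    for i in range(N):
        X_loop[p, i] = sum(w[p, j] * A[j, k] * Y[p, k]
                           for j in range(N) if j != i
                           for k in range(N) if k != i)

print(np.allclose(X_vectorized(w, A, Y), X_loop))  # True
```

The cost is dominated by the two matrix products, roughly O(P N^2), with no Python-level loop over p or i; the explicit loop is only there as a check.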
Didn't your question initially have (2,3) and (3,3) arrays? That, or something a bit larger, may be more general and informative.