Fastest way to find the maximum minimum value of three 'connected' matrices


The answer for two matrices was given in this question, but I'm not sure how to apply this logic to three pairwise connected matrices since there are no 'free' indices. I want to maximize the following function:

f(i, j, k) = min(A(i, j), B(j, k), C(i, k))

Where A, B and C are matrices and i, j and k are indices that range up to the respective dimensions of the matrices. I would like to find (i, j, k) such that f(i, j, k) is maximized. I am currently doing that as follows:

import numpy as np
import itertools

I = 100
J = 150
K = 200

A = np.random.rand(I, J)
B = np.random.rand(J, K)
C = np.random.rand(I, K)

# All the different i,j,k
combinations = itertools.product(np.arange(I), np.arange(J), np.arange(K))
combinations = np.asarray(list(combinations))

A_vals = A[combinations[:,0], combinations[:,1]]
B_vals = B[combinations[:,1], combinations[:,2]]
C_vals = C[combinations[:,0], combinations[:,2]]

f = np.min([A_vals,B_vals,C_vals],axis=0)

best_indices = combinations[np.argmax(f)]
print(best_indices)

[ 49 14 136]

This is faster than iterating over all (i, j, k), but most of the time is spent constructing the _vals arrays. This is unfortunate, because they contain many duplicate values, since the same i, j and k appear many times. Is there a way to do this where (1) the speed of numpy's vectorized computation is preserved and (2) I don't have to construct the memory-intensive _vals arrays?

In other languages you could maybe construct the matrices so that they contain pointers to A, B and C, but I do not see how to achieve this in Python.

CodePudding user response:

Instead of using itertools, you can "build" the combinations with np.repeat and np.tile:

# each A[i, j] repeated K times in a row, following the (i, j, k) ordering of itertools.product
A_ = np.repeat(A.reshape((-1, 1)), K, axis=0).T
# B flattened to length J*K and the whole block tiled I times
B_ = np.tile(B.reshape((-1, 1)), (I, 1)).T
# each row C[i, :] tiled J times, then flattened row by row
C_ = np.tile(C, J).reshape((-1, 1)).T

And passing them to np.min:

print((t := np.argmax(np.min([A_, B_, C_], axis=0)), t // (K * J), (t // K) % J, t % K))

With timeit, 10 repetitions of your code take around 18 seconds, while the numpy-only version takes about 1 second.
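As a side note, np.unravel_index can recover (i, j, k) from the flat argmax instead of doing the division and modulo by hand. A minimal sketch (the small sizes here are not from the question; they only make a cross-check against the itertools version cheap):

import itertools
import numpy as np

# small, arbitrary sizes just for the cross-check
I, J, K = 4, 5, 6
A = np.random.rand(I, J)
B = np.random.rand(J, K)
C = np.random.rand(I, K)

A_ = np.repeat(A.reshape((-1, 1)), K, axis=0).T
B_ = np.tile(B.reshape((-1, 1)), (I, 1)).T
C_ = np.tile(C, J).reshape((-1, 1)).T

t = np.argmax(np.min([A_, B_, C_], axis=0))
i, j, k = np.unravel_index(t, (I, J, K))     # same as t//(K*J), (t//K)%J, t%K

# cross-check against the itertools construction from the question
combos = np.asarray(list(itertools.product(range(I), range(J), range(K))))
f = np.min([A[combos[:, 0], combos[:, 1]],
            B[combos[:, 1], combos[:, 2]],
            C[combos[:, 0], combos[:, 2]]], axis=0)
assert tuple(combos[np.argmax(f)]) == (i, j, k)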

CodePudding user response:

We can either brute force it using numpy broadcasting or try a bit of smart branch cutting:

import numpy as np

def bf(A, B, C):
    # brute force: broadcast to the full (I, J, K) cube of minima, then argmax
    I, J = A.shape
    J, K = B.shape
    cube = np.minimum(np.minimum(A[:, :, None], C[:, None, :]), B[None, :, :])
    return np.unravel_index(cube.argmax(), (I, J, K))

def cut(A, B, C):
    # branch cutting: visit the entries of A from largest to smallest and stop as
    # soon as no remaining A value can beat the best minimum found so far
    gmx = float("-inf")   # best min(A, B, C) value seen so far
    gamx = None           # indices (i, j, k) of that best value
    I, J = A.shape
    J, K = B.shape
    Y, X = np.unravel_index(A.argsort(axis=None)[::-1], A.shape)
    for y, x in zip(Y, X):
        if A[y, x] <= gmx:
            # every remaining A value is <= gmx, so the result cannot improve
            return gamx
        curr = np.minimum(B[x, :], C[y, :])
        camx = curr.argmax()
        cmx = curr[camx]
        if cmx >= A[y, x]:
            # the minimum is capped by A[y, x] itself; no later entry can do better
            return y, x, camx
        if gmx < cmx:
            gmx = cmx
            gamx = y, x, camx
    return gamx
            
from timeit import timeit

I = 100
J = 150
K = 200

for rep in range(4):
    print("trial",rep 1)
    A = np.random.rand(I, J)
    B = np.random.rand(J, K)
    C = np.random.rand(I, K)

    print("results identical",cut(A,B,C)==bf(A,B,C))
    print("brute force",timeit(lambda:bf(A,B,C),number=2)*500,"ms")
    print("branch cut",timeit(lambda:cut(A,B,C),number=10)*100,"ms")

It turns out that at the given sizes branch cutting is well worth it:

trial 1
results identical True
brute force 169.74265850149095 ms
branch cut 1.951422297861427 ms
trial 2
results identical True
brute force 180.37619898677804 ms
branch cut 2.1000938024371862 ms
trial 3
results identical True
brute force 181.6371419990901 ms
branch cut 1.999850495485589 ms
trial 4
results identical True
brute force 217.75578951928765 ms
branch cut 1.5871295996475965 ms

How does the branch cutting work?

We pick one array (say A) and sort its entries from largest to smallest. We then walk through them one by one, comparing each entry against the corresponding rows of the other two arrays and keeping track of the running maximum of minima. As soon as that running maximum is at least as large as the largest remaining value in A, no later entry can improve the result and we are done. Since this typically happens after only a small fraction of the entries, we get a huge saving.
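To see how early the exit typically fires, here is a small instrumented variant (a sketch only; cut_counted is a made-up name, not part of the answer above) that counts how many sorted entries of A are actually visited before termination:

import numpy as np

def cut_counted(A, B, C):
    # same branch-cutting logic as cut(), but also returns the number of
    # sorted A entries visited before termination
    I, J = A.shape
    J, K = B.shape
    gmx = float("-inf")
    gamx = None
    Y, X = np.unravel_index(A.argsort(axis=None)[::-1], A.shape)
    for n, (y, x) in enumerate(zip(Y, X), start=1):
        if A[y, x] <= gmx:
            return gamx, n
        curr = np.minimum(B[x, :], C[y, :])
        camx = curr.argmax()
        cmx = curr[camx]
        if cmx >= A[y, x]:
            return (y, x, camx), n
        if gmx < cmx:
            gmx = cmx
            gamx = y, x, camx
    return gamx, len(Y)

A = np.random.rand(100, 150)
B = np.random.rand(150, 200)
C = np.random.rand(100, 200)
best, visited = cut_counted(A, B, C)
print(best, "after visiting", visited, "of", A.size, "entries of A")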
