Fastest way to find the maximum minimum value of 'connected' matrices

Time:10-09

The answer for three matrices was given in this question, but I'm not sure how to apply that logic to an arbitrary number of pairwise connected matrices:

f(i, j, k, l, ...) = min(A(i, j), B(i,k), C(i,l), D(j,k), E(j,l), F(k,l), ...)

where A, B, ... are matrices and i, j, ... are indices that range up to the respective dimensions of the matrices. With n indices there are n(n-1)/2 pairs, and thus n(n-1)/2 matrices. I would like to find (i, j, k, l, ...) such that f(i, j, k, l, ...) is maximized. I am currently doing that as follows:

import numpy as np
import itertools

#             i  j  k  l  ...
dimensions = [50,50,50,50]
n_dims = len(dimensions)

pairs = list(itertools.combinations(range(n_dims), 2))

# Construct the matrices A(i,j), B(i,k), ...
matrices = []
for pair in pairs:
    matrices.append(np.random.rand(dimensions[pair[0]], dimensions[pair[1]]))


# All the different i,j,k,l... combinations
combinations = itertools.product(*list(map(np.arange,dimensions)))
combinations = np.asarray(list(combinations))

# Find the maximum minimum
vals = []

for i in range(len(pairs)):
    pair = pairs[i]
    matrix = matrices[i]
    vals.append(matrix[combinations[:,pair[0]], combinations[:,pair[1]]])


f = np.min(vals,axis=0)

best_indices = combinations[np.argmax(f)]

print(best_indices, np.max(f))

[5 17 17 18] 0.932985854758534

This is faster than iterating over all (i, j, k, l, ...), but a lot of time is spent constructing the combinations and vals matrices. Is there an alternative way to do this where (1) the speed of numpy's matrix computation can be preserved and (2) I don't have to construct the memory-intensive vals matrices?
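One possible way to skip the combinations and vals arrays is to let broadcasting build f directly. The sketch below is only an illustration under stated assumptions: the function name max_min_broadcast is hypothetical, the matrices are assumed to be ordered like itertools.combinations(range(n), 2), and the full f array of prod(dimensions) floats is assumed to still fit in memory.

```python
import itertools
import numpy as np

def max_min_broadcast(matrices, dimensions):
    """Maximise the minimum over all pairwise matrices via broadcasting.

    Assumes matrices are ordered like itertools.combinations(range(n), 2).
    """
    n = len(dimensions)
    f = None
    for (a, b), m in zip(itertools.combinations(range(n), 2), matrices):
        # Place the two axes of m at positions a and b of an n-D array;
        # every other axis has length 1 and is broadcast.
        shape = [1] * n
        shape[a], shape[b] = dimensions[a], dimensions[b]
        view = m.reshape(shape)
        f = view if f is None else np.minimum(f, view)
    best = np.unravel_index(np.argmax(f), f.shape)
    return tuple(int(i) for i in best), float(f[best])

# Small usage example with random data (sizes chosen only for illustration).
rng = np.random.default_rng(0)
dimensions = [20, 30, 40, 25]
matrices = [rng.random((dimensions[a], dimensions[b]))
            for a, b in itertools.combinations(range(len(dimensions)), 2)]
best_indices, best_value = max_min_broadcast(matrices, dimensions)
print(best_indices, best_value)
```

This still visits every index combination, so it only addresses the memory overhead of combinations and vals, not the exponential growth of the search space.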

CodePudding user response:

Here is a generalisation of the 3D solution. There are probably other (better?) ways of organising the recursion, but this works well enough: it solves a 6D example (product of dimensions 9x10^6) in under 10 ms.

Sample run below. Note that the indices returned by the two methods occasionally do not match. This is because the maximiser is not always unique: sometimes different index combinations yield the same maximum of minima. Also note that at the very end we do a single run of a huge 6D example (product of dimensions 9x10^12). Brute force is no longer viable at that size; the smart method takes about 10 seconds.

trial 1
results identical True
results compatible True
brute force 276.8830654968042 ms
branch cut 9.971900499658659 ms
trial 2
results identical True
results compatible True
brute force 273.444719001418 ms
branch cut 9.236706099909497 ms
trial 3
results identical True
results compatible True
brute force 274.2998780013295 ms
branch cut 7.31226220013923 ms
trial 4
results identical True
results compatible True
brute force 273.0268925006385 ms
branch cut 6.956217200058745 ms
HUGE (100, 150, 200, 100, 150, 200) 9000000000000
branch cut 10246.754082996631 ms
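The remark about non-unique maximisers can be checked by comparing the values of f at the returned index tuples rather than the tuples themselves. A minimal sketch, assuming the matrices are ordered like itertools.combinations over the axes; f_at is a hypothetical helper and not part of the answer's code:

```python
import itertools
import numpy as np

def f_at(pairs, idx):
    """Evaluate the minimum over all pairwise matrices at one index tuple."""
    return min(p[a, b]
               for p, (a, b) in zip(pairs, itertools.combinations(idx, 2)))

# Contrived example in which every index tuple ties: three axes of sizes 2, 3, 2.
pairs = [np.full((2, 3), 0.5), np.full((2, 2), 0.5), np.full((3, 2), 0.5)]
# Two different tuples give the same value, so both are valid answers.
print(f_at(pairs, (0, 0, 0)) == f_at(pairs, (1, 2, 1)))
```

Two solvers agree whenever f_at gives the same value for both results, even if the tuples differ.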

Code:

import numpy as np
import itertools as it
import functools as ft

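def bf(dims,pairs):
    # Brute force: broadcast every matrix onto the full n-D grid and take the argmax.
    dims,pairs = np.array(dims),np.array(pairs,object)
    n,m = len(dims),len(pairs)
    # IDX[r] indexes matrix r: slices on its two axes, None (a new axis) elsewhere.
    # np.empty with dtype object initialises every entry to None.
    IDX = np.empty((m,n),object)
    Y,X = np.triu_indices(n,1)
    IDX[np.arange(m),Y] = slice(None)
    IDX[np.arange(m),X] = slice(None)
    idx = np.unravel_index(
        ft.reduce(np.minimum,(p[(*i,)] for p,i in zip(pairs,IDX))).argmax(),dims)
    # Re-evaluate the minimum at the winning index tuple to get its value.
    return ft.reduce(np.minimum,(
        p[I] for p,I in zip(pairs,it.combinations(idx,2)))),idx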

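def cut(dims,pairs,offs=None):
    # offs[d] caps the values along remaining axis d, given the indices already
    # fixed further up the recursion (None at the top level).
    n = len(dims)
    if n<3:
        # Base cases: at most one matrix is left, so solve directly.
        if n==2:
            A = pairs[0] if offs is None else np.minimum(
                pairs[0],np.minimum.outer(offs[0],offs[1]))
            idx = np.unravel_index(A.argmax(),dims)
            return A[idx],idx
        else:
            idx = offs[0].argmax()
            return offs[0][idx],(idx,)
    # Initial bound: the smallest entry of any remaining matrix.
    gmx = min(map(np.min,pairs))
    gidx = n * (0,)
    # Matrix for the first two axes, capped by the inherited bounds.
    A = pairs[0] if offs is None else np.minimum(
        pairs[0],np.minimum.outer(offs[0],offs[1]))
    # Visit the entries of A from largest to smallest.
    Y,X = np.unravel_index(A.argsort(axis=None)[::-1],dims[:2])
    for y,x in zip(Y,X):
        if A[y,x] <= gmx:
            # A[y,x] bounds the result from above, so no later entry can beat gmx.
            return gmx,gidx
        # Caps on each remaining axis from the matrices sharing axis 0 (row y)
        # or axis 1 (row x).
        coffs = [np.minimum(p1[y],p2[x])
                 for p1,p2 in zip(pairs[1:n-1],pairs[n-1:])]
        if offs is not None:
            coffs = [*map(np.minimum,coffs,offs[2:])]
        # Solve the subproblem on the remaining axes.
        cmx,cidx = cut(dims[2:],pairs[2*n-3:],coffs)
        if cmx >= A[y,x]:
            # A[y,x] is attained and no later entry is larger, so this is optimal.
            return A[y,x],(y,x,*cidx)
        if gmx < cmx:
            gmx = min(A[y,x],cmx)
            gidx = y,x,*cidx
    return gmx,gidx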

from timeit import timeit

IDX = 10,15,20,10,15,20

for rep in range(4):
    print("trial",rep+1)
    pairs = [np.random.rand(i,j) for i,j in it.combinations(IDX,2)]

    print("results identical",cut(IDX,pairs)==bf(IDX,pairs))
    print("results compatible",cut(IDX,pairs)[1]==bf(IDX,pairs)[1])
    print("brute force",timeit(lambda:bf(IDX,pairs),number=2)*500,"ms")
    print("branch cut",timeit(lambda:cut(IDX,pairs),number=10)*100,"ms")

IDX = 100,150,200,100,150,200
pairs = [np.random.rand(i,j) for i,j in it.combinations(IDX,2)]
print("HUGE",IDX,np.prod(IDX))
print("branch cut",timeit(lambda:cut(IDX,pairs),number=1)*1000,"ms")
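
The early-exit idea that cut relies on is easier to see with only three matrices. The following is an illustrative sketch of that idea as read from the code above, not the original 3D answer; max_min_three and its argument names are hypothetical. It maximises min(A[i, j], B[i, k], C[j, k]) by visiting the entries of A in descending order and stopping as soon as the current entry can no longer beat the best value found.

```python
import numpy as np

def max_min_three(A, B, C):
    """Maximise min(A[i, j], B[i, k], C[j, k]) with early termination."""
    best_val, best_idx = -np.inf, None
    # Flat indices of A sorted from the largest entry to the smallest.
    order = np.argsort(A, axis=None)[::-1]
    for i, j in zip(*np.unravel_index(order, A.shape)):
        if A[i, j] <= best_val:
            # A[i, j] is an upper bound on f(i, j, k) and all later entries
            # are no larger, so nothing can improve on best_val any more.
            break
        # Best k for this (i, j): maximise min(B[i, k], C[j, k]).
        inner = np.minimum(B[i], C[j])
        k = int(inner.argmax())
        val = min(A[i, j], inner[k])
        if val > best_val:
            best_val, best_idx = val, (int(i), int(j), k)
    return best_idx, float(best_val)

# Usage with random data (shapes chosen only for illustration).
rng = np.random.default_rng(0)
A, B, C = rng.random((10, 15)), rng.random((10, 20)), rng.random((15, 20))
print(max_min_three(A, B, C))
```

The loop usually stops after a small fraction of the entries of A, which is why pruning can beat the full broadcast when the matrices are large.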