Find pivots in tensor that maximize sum of values

Time:04-30

Let's suppose I have a matrix like this:

[[15,10,8],
[11,5,8],
[9,14,4]]

I need to write a function that picks one element per row, never using the same column index twice, so that the sum of the picked values is as large as possible, and returns the indices of those elements.

Given the previous matrix, the best solution would be the following:

[[0,0],
[1,2],
[2,1]]

That's because the sum of the values at those indices is the largest possible (15 + 8 + 14 = 37) and no column index is repeated in the output tensor.
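
To make the expected result concrete, here is a minimal brute-force sketch in plain Python (not TensorFlow, and only practical for tiny matrices). It tries every way of giving each row a distinct column and keeps the one with the largest sum. The names `a`, `best_cols`, `pivots` and `total` are made up for this illustration.

```python
from itertools import permutations

a = [[15, 10, 8],
     [11, 5, 8],
     [9, 14, 4]]
n = len(a)

# Each permutation assigns row i the column cols[i], so no column repeats.
best_cols = max(permutations(range(n)),
                key=lambda cols: sum(a[i][cols[i]] for i in range(n)))

pivots = [[i, best_cols[i]] for i in range(n)]
total = sum(a[i][j] for i, j in pivots)
print(pivots, total)  # [[0, 0], [1, 2], [2, 1]] 37
```

This checks all n! permutations, so it only serves to confirm the example, not as a real solution.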

This is needed in a loss function, so I need it written using only TensorFlow.

Thanks

CodePudding user response:

That is called an assignment problem. This can be solved as an LP (Linear Programming) problem:

  max sum((i,j), a[i,j]*x[i,j])
  subject to
       sum(j, x[i,j]) = 1  ∀i
       sum(i, x[i,j]) = 1  ∀j
       x[i,j] ∈ [0,1]

This can be solved with any LP solver.
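
As a rough sketch of the LP above, assuming SciPy's `scipy.optimize.linprog` is an acceptable solver: `x` is flattened row by row, the objective is negated because `linprog` minimizes, and the result is rounded since the solver returns floats. The variable names are illustrative only.

```python
import numpy as np
from scipy.optimize import linprog

a = np.array([[15, 10, 8],
              [11, 5, 8],
              [9, 14, 4]], dtype=float)
n = a.shape[0]

# linprog minimizes, so negate the values to maximize the sum.
c = -a.ravel()

# Equality constraints on the flattened x (x[i, j] lives at index i * n + j).
A_eq = np.zeros((2 * n, n * n))
for i in range(n):
    A_eq[i, i * n:(i + 1) * n] = 1   # sum(j, x[i,j]) = 1 for row i
for j in range(n):
    A_eq[n + j, j::n] = 1            # sum(i, x[i,j]) = 1 for column j
b_eq = np.ones(2 * n)

res = linprog(c, A_eq=A_eq, b_eq=b_eq, bounds=(0, 1))

# Round the float solution to a 0/1 matrix and read off the chosen cells.
x = np.round(res.x).reshape(n, n).astype(int)
pivots = np.argwhere(x == 1)
total = float(a[x == 1].sum())
print(pivots.tolist(), total)
```

For this matrix the optimum is unique, so the rounded solution is the 0/1 assignment from the question.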

There are also specialized algorithms for the assignment problem. See e.g.: https://docs.scipy.org/doc/scipy-0.18.1/reference/generated/scipy.optimize.linear_sum_assignment.html.
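
A minimal sketch using that function on the matrix from the question. `linear_sum_assignment` minimizes cost, so the matrix is negated here to get the maximum sum instead; the variable names are only for illustration.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

a = np.array([[15, 10, 8],
              [11, 5, 8],
              [9, 14, 4]])

# The solver minimizes total cost, so negate to maximize the sum of values.
rows, cols = linear_sum_assignment(-a)

pivots = np.stack([rows, cols], axis=1)   # one [row, column] pair per row
total = int(a[rows, cols].sum())
print(pivots.tolist(), total)  # [[0, 0], [1, 2], [2, 1]] 37
```

This runs in NumPy/SciPy rather than in TensorFlow. One possible way to use it inside a loss (an assumption, not tested here) is to wrap the call in `tf.numpy_function` and then gather the selected values with `tf.gather_nd`; the indices themselves carry no gradient, only the gathered values do.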
