NumPy get indices per column of the entries which are equal to max value per column in a 2D-array

Time:02-20

I have a 2D array:

A = np.array([[2,3,4],
              [2,0,4],
              [1,3,7]])

For each column, I am looking for the row indices of the entries that equal that column's maximum, without using a for loop.

What I would like to get is something like:

max_rowIndices_perColumn = np.array([[0,1],[0,2],[2]])

I had the idea to use:

np.where(A== np.amax(A,axis=0)) 

but in a second step I would like to work with each column separately, and this flat result does not group the indices by column, so I am not really happy with this idea.

Thank you in advance

CodePudding user response:

You need a somewhat deeper understanding of how indexing behaves.

With mask = A == np.amax(A, axis=0), np.where returns the indices of the True cells in C order (row by row):

>>> np.where(mask)
(array([0, 0, 1, 2, 2]), array([0, 1, 0, 1, 2]))

but you need to do it in Fortran order (column by column) like so:

>>> np.where(mask, order='F') # not working, it doesn't support order parameter
(array([0, 1, 0, 2, 2]), array([0, 0, 1, 1, 2]))

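np.where has no order parameter, but you can pass mask.T instead. The traversal is then column by column, and the two output arrays are swapped (column indices first, row indices second):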

>>> np.where(mask.T) # fix
(array([0, 0, 1, 1, 2]), array([0, 1, 0, 2, 2]))

The remaining part is to split the row indices into one group per column. In conclusion, you could solve your problem like so:

>>> mask = A == np.amax(A, axis=0)
>>> x, y = np.where(mask.T)
>>> div_points = np.flatnonzero(np.diff(x)) + 1
>>> np.split(y, div_points)
[array([0, 1]), array([0, 2]), array([2])]
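
For convenience, the steps above can be wrapped into one self-contained function. This is only a sketch: the name max_rows_per_column is made up for illustration, and it assumes A is a 2D array without NaN values (a column containing NaN has no entry equal to its maximum, which would shift the groups).

```python
import numpy as np

def max_rows_per_column(A):
    """Return a list with one array of row indices per column of A."""
    # True where an entry equals the maximum of its column
    mask = A == np.amax(A, axis=0)
    # Transposing makes np.where scan column by column:
    # cols holds the column of each hit, rows the matching row
    cols, rows = np.where(mask.T)
    # cols is sorted, so a new group starts wherever its value changes
    div_points = np.flatnonzero(np.diff(cols)) + 1
    return np.split(rows, div_points)

A = np.array([[2, 3, 4],
              [2, 0, 4],
              [1, 3, 7]])
groups = max_rows_per_column(A)
print([g.tolist() for g in groups])  # [[0, 1], [0, 2], [2]]
```

Every column has at least one maximum, so np.split returns exactly one group per column and groups[j] belongs to column j.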

CodePudding user response:

Define a function to get indices of max value in a column:

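def idxMax(col):
    # np.unique sorts ascending, so negating col puts the maximum first (inverse index 0)
    _, inv = np.unique(-col, return_inverse=True)
    # rows whose inverse index is 0 hold the column maximum
    return np.where(inv == 0)[0].tolist()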

Then generate the result as:

result = np.array([ idxMax(col) for col in A.T ], dtype=object)

For your source data, the result is:

array([list([0, 1]), list([0, 2]), list([2])], dtype=object)

Note that in the general case there is no guarantee that every column has the same number of maximum entries, so the result is a "ragged" array, and NumPy then requires dtype=object to be passed.
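
One related detail, as a hedged side note: when all columns happen to have the same number of maxima, np.array(..., dtype=object) builds a 2D array instead of a 1D array of lists. If a consistent 1D shape matters, one option is to fill a preallocated object array. The helper name to_object_array below is hypothetical, not part of NumPy.

```python
import numpy as np

def to_object_array(lists):
    # Preallocate a 1D array of Python objects, then store one list per slot
    out = np.empty(len(lists), dtype=object)
    for i, item in enumerate(lists):
        out[i] = item
    return out

ragged = to_object_array([[0, 1], [0, 2], [2]])
uniform = to_object_array([[0], [1], [2]])
print(ragged.shape, uniform.shape)  # (3,) (3,)

# For comparison: equal-length lists are turned into a 2D object array
direct = np.array([[0], [1], [2]], dtype=object)
print(direct.shape)  # (3, 1)
```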

But if a plain Python list of lists (instead of a NumPy array) is enough for you, you can shorten the above code to:

result = [ idxMax(col) for col in A.T ]

In this case the result is:

[[0, 1], [0, 2], [2]]
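
As a possible simplification (a sketch, not the code from the answer above), the per-column helper can also compare the column with its own maximum directly instead of going through np.unique. The name idx_max is chosen here only for illustration.

```python
import numpy as np

def idx_max(col):
    # Row positions where the column equals its own maximum
    return np.flatnonzero(col == col.max()).tolist()

A = np.array([[2, 3, 4],
              [2, 0, 4],
              [1, 3, 7]])
result = [idx_max(col) for col in A.T]  # A.T iterates over the columns of A
print(result)  # [[0, 1], [0, 2], [2]]
```

This avoids negating the column, so it should also behave for unsigned integer dtypes, where negation wraps around.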