Fastest way to get max frequency element for every row of numpy matrix

Time:10-02

Given a 2D numpy matrix X of shape [m,n], all of whose values are guaranteed to be integers between 0 and 9 inclusive, I wish to calculate, for every row, the value that occurs most frequently in that row (breaking ties by returning the highest value), and output the resulting array of length m. A short example would be as follows:

X = [[1,2,3,4],
     [0,0,6,9],
     [5,7,7,5],
     [1,0,0,0],
     [1,8,1,8]]

The output for the above matrix should be:

y = [4,0,7,0,8]

Consider the first row: all elements occur with the same frequency, so the numerically greatest of them, 4, is chosen. In the second row, 0 is the only number with the highest frequency. In the third row, both 5 and 7 occur twice, so 7 is chosen, and so on.

I could do this by maintaining collections.Counter objects for each row and then choosing the number satisfying the criteria. A naive implementation which I tried:

from collections import Counter

import numpy as np

X = np.array([[1,2,3,4],[0,0,6,9],[5,7,7,5],[1,0,0,0],[1,8,1,8]])
y = np.zeros(len(X), dtype=int)

for i in range(len(X)):
    freq_count = Counter(X[i])
    max_freq, max_freq_val = 0, -1
    # Scan values in ascending order; >= lets a larger value win ties.
    for val in range(10):
        if freq_count.get(val, 0) >= max_freq:
            max_freq = freq_count.get(val, 0)
            max_freq_val = val
    y[i] = max_freq_val

print(y)  # prints [4 0 7 0 8]

But using Counters is not fast enough. Is it possible to improve the running time, perhaps by using vectorization? Here m is on the order of 5e4 and n = 45.

CodePudding user response:

Given that the numbers are always integers between 0 and 9, you could use numpy.bincount to count the occurrences of each value, then use numpy.argmax on a reversed view ([::-1]) to find the last occurrence of the maximum count, which corresponds to the largest of the tied values:

import numpy as np

X = np.array([[1, 2, 3, 4],
              [0, 0, 6, 9],
              [5, 7, 7, 5],
              [1, 0, 0, 0],
              [1, 8, 1, 8]])

res = [9 - np.bincount(row, minlength=10)[::-1].argmax() for row in X]
print(res)

Output

[4, 0, 7, 0, 8]
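
To see why the reversed view handles the tie-break, here is a small sketch on a single row (the variable names first and last are only illustrative). argmax returns the first index of the maximum, so applying it to the reversed counts and subtracting from 9 recovers the last index in the original order:

```python
import numpy as np

# Counts for the row [5, 7, 7, 5]: values 5 and 7 both occur twice.
counts = np.bincount(np.array([5, 7, 7, 5]), minlength=10)

# argmax returns the FIRST index holding the maximum,
# i.e. the smallest of the tied values.
first = int(counts.argmax())

# On the reversed view the first maximum is the last one in the
# original order; 9 - index maps it back to the original position.
last = 9 - int(counts[::-1].argmax())

print(first, last)  # 5 7
```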

According to the timings here, np.bincount is pretty fast. For more details on using argmax to find the last occurrence of the maximum value, read this.
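
Since the question also asks about vectorization, the per-row list comprehension could probably be replaced by a single bincount call. The following is a sketch, not part of the original answer: the helper name row_modes and its n_values parameter are hypothetical, and it assumes X holds non-negative integers below n_values. Each row is shifted into its own block of 10 bins, so one flat bincount yields an (m, 10) table of counts:

```python
import numpy as np

def row_modes(X, n_values=10):
    """Most frequent value per row; ties go to the largest value.

    Hypothetical helper: assumes X is a 2D array of integers
    in the range [0, n_values).
    """
    X = np.asarray(X)
    m = X.shape[0]
    # Shift row i by i * n_values so every row counts into its own bins.
    offsets = np.arange(m)[:, None] * n_values
    counts = np.bincount((X + offsets).ravel(), minlength=m * n_values)
    counts = counts.reshape(m, n_values)
    # Reverse the value axis so argmax picks the largest tied value.
    return (n_values - 1) - counts[:, ::-1].argmax(axis=1)

X = np.array([[1, 2, 3, 4],
              [0, 0, 6, 9],
              [5, 7, 7, 5],
              [1, 0, 0, 0],
              [1, 8, 1, 8]])

y = row_modes(X)
print(y)  # [4 0 7 0 8]
```

Whether this is actually faster than the list comprehension for m around 5e4 and n = 45 is worth measuring (for example with timeit) rather than assuming.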
