Find the top k elements of each row of a NumPy ndarray, ignoring zeros

Time:07-27

Given a NumPy ndarray like the following:

x = [[4.,0.,2.,0.,8.],
     [1.,3.,0.,9.,5.],
     [0.,0.,4.,0.,1.]]

I want to find the indices of the top k (e.g. k=3) elements of each row, excluding zeros. If a row has fewer than k positive elements, just return the indices of those elements (still ordered from largest to smallest value).

The result should look like this (a list of arrays):

res = [[4, 0, 2],
       [3, 4, 1],
       [2, 4]]

or a single flattened array:

res = [4,0,2,3,4,1,2,4]

I know argsort can return the indices of the top k elements in sorted order, but I am not sure how to filter out the zeros.

CodePudding user response:

You can call numpy.argsort on -x to get the column indices of each row in descending order of value, then use numpy.take_along_axis to pull the values out in that same order. Since you only want the top three, set every column after the third in the sorted values to zero. Finally, return the indices whose sorted value is not zero; this drops both the original zeros and everything past the top three.

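import numpy as np

k = 3
x = np.array([[4.,0.,2.,0.,8.],[1.,3.,0.,9.,5.],[0.,0.,4.,0.,1.]])
idx_srt = np.argsort(-x)                           # column indices, largest value first
val_srt = np.take_along_axis(x, idx_srt, axis=-1)  # values in that same order
val_srt[:, k:] = 0                                 # discard everything past the top k
res = idx_srt[val_srt!=0]                          # keep only indices of nonzero values
print(res)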

[4 0 2 3 4 1 2 4]
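The same idea can be wrapped in a function that takes k as a parameter and also returns the list-of-arrays form asked for in the question. This is only a sketch: the helper name topk_nonzero_indices is made up for illustration, and it assumes the entries are non-negative (a negative value is nonzero, so it would be kept if it landed in the first k sorted positions).

```python
import numpy as np

def topk_nonzero_indices(x, k):
    """Hypothetical helper: per-row indices of the k largest nonzero entries."""
    x = np.asarray(x)
    idx_srt = np.argsort(-x, axis=-1)                  # column indices, largest value first
    val_srt = np.take_along_axis(x, idx_srt, axis=-1)  # values in that same order
    keep = val_srt != 0                                # drop zeros
    keep[:, k:] = False                                # keep at most k entries per row
    flat = idx_srt[keep]                               # flattened result, row by row
    counts = keep.sum(axis=1)                          # how many indices each row contributed
    per_row = np.split(flat, np.cumsum(counts)[:-1])   # cut the flat array back into rows
    return flat, per_row

x = np.array([[4., 0., 2., 0., 8.],
              [1., 3., 0., 9., 5.],
              [0., 0., 4., 0., 1.]])
flat, per_row = topk_nonzero_indices(x, k=3)
print(flat)     # [4 0 2 3 4 1 2 4]
print(per_row)  # three arrays: [4 0 2], [3 4 1], [2 4]
```

Using a boolean keep mask instead of overwriting val_srt leaves the sorted values intact, and the per-row counts are what allow np.split to rebuild the ragged result.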

CodePudding user response:

Try one of these two:

import numpy as np  # only needed for the second variant

k = 3

res = [sorted(range(len(r)), key=(lambda i: r[i]), reverse=True)[:min(k, len([n for n in r if n > 0]))] for r in x]

or

res1 = [np.argsort(r)[::-1][:min(k, len([n for n in r if n > 0]))] for r in x]

CodePudding user response:

I came up with the following solution:

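# score here is my x, and k is the number of top elements to keep
top_index = score.argsort(axis=1)            # ascending order within each row
positive = (score > 0).sum(axis=1)           # number of positive entries per row
positive = np.minimum(positive, k)           # keep at most k per row

# broadcasting trick to get a mask matrix that selects the top min(k, num of positive scores) entries of each row (k = 2000 in my case)
r = np.arange(score.shape[1])
mask = (positive[:,None] > r)
top_index_flatten = top_index[:, ::-1][mask]  # reverse to descending order, then apply the mask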

I compared my result with the one suggested by @I'mahdi, and they are consistent.
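A quick way to reproduce that comparison on the array from the question is sketched below. The names res_a and res_b are introduced here only for illustration, and the check covers just this one sample, which has no ties among its positive entries; with ties the two approaches could order equal values differently.

```python
import numpy as np

x = np.array([[4., 0., 2., 0., 8.],
              [1., 3., 0., 9., 5.],
              [0., 0., 4., 0., 1.]])
k = 3

# Approach from the first answer: sort descending, zero out columns past k, drop zeros
idx_srt = np.argsort(-x, axis=1)
val_srt = np.take_along_axis(x, idx_srt, axis=1)  # a new array, so x is not modified
val_srt[:, k:] = 0
res_a = idx_srt[val_srt != 0]

# Mask approach: count positives per row, cap at k, select that many leading columns
top_index = x.argsort(axis=1)
positive = np.minimum((x > 0).sum(axis=1), k)
mask = positive[:, None] > np.arange(x.shape[1])
res_b = top_index[:, ::-1][mask]

print(res_b)                         # [4 0 2 3 4 1 2 4]
print(np.array_equal(res_a, res_b))  # True
```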
