How to get index of multiple, possibly different, elements in numpy?

Time:10-04

I have a numpy array with many rows in it that look roughly as follows:

0, 50, 50, 2, 50, 1, 50, 99, 50, 50
50, 2, 1, 50, 50, 50, 98, 50, 50, 50
0, 50, 50, 98, 50, 1, 50, 50, 50, 50
0, 50, 50, 50, 50, 99, 50, 50, 2, 50
2, 50, 50, 0, 98, 1, 50, 50, 50, 50
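To reproduce the answers below, the example rows can be loaded as a NumPy array. The answers refer to it as `data` and use `n = 2`; the exact construction here is a sketch:

```python
import numpy as np

# The five example rows from the question, as a 5x10 integer array.
data = np.array([
    [0, 50, 50, 2, 50, 1, 50, 99, 50, 50],
    [50, 2, 1, 50, 50, 50, 98, 50, 50, 50],
    [0, 50, 50, 98, 50, 1, 50, 50, 50, 50],
    [0, 50, 50, 50, 50, 99, 50, 50, 2, 50],
    [2, 50, 50, 0, 98, 1, 50, 50, 50, 50],
])
n = 2
print(data.shape)  # (5, 10)
```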

I am given a variable n<50. Each row, of length 10, has the following in it:

  • Every number from 0 to n, with one possibly missing. In the example above, n=2.
  • Possibly a 98, which will be in the place of the missing number, if there is a number missing.
  • Possibly a 99, which will be in the place of the missing number, if there is a number missing, and there is not already a 98.
  • Many 50's.

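What I want is an array whose first row holds the index of the 0 in each row, whose second row holds the index of the 1 in each row, whose third row holds the index of the 2 in each row, and so on. When a number is missing from a row, the index of the 98 or 99 that replaced it should be used instead. For the above example, my desired output is this: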

0, 6, 0, 0, 3
5, 2, 5, 5, 5
3, 1, 3, 8, 0

You may have noticed the catch: sometimes, exactly one of the numbers is replaced either by a 98, or a 99. It's pretty easy to write a for loop which determines which number, if any, was replaced, and uses that to get the array of indices.
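For comparison, here is a minimal sketch of such a loop (not the asker's code; `indices_loop` is a hypothetical name). It assumes, as described above, that a missing number is replaced by a 98 if one is present and otherwise by a 99:

```python
import numpy as np


def indices_loop(data, n):
    """Row k of the result holds, for each input row, the column of value k,
    or the column of the 98/99 placeholder if k is missing from that row."""
    out = np.empty((n + 1, data.shape[0]), dtype=int)
    for r, row in enumerate(data):
        # Placeholder column: prefer a 98, fall back to a 99, else -1 (assumed sentinel).
        hits98 = np.flatnonzero(row == 98)
        hits99 = np.flatnonzero(row == 99)
        if hits98.size:
            placeholder = hits98[0]
        elif hits99.size:
            placeholder = hits99[0]
        else:
            placeholder = -1
        for k in range(n + 1):
            hits = np.flatnonzero(row == k)
            out[k, r] = hits[0] if hits.size else placeholder
    return out


data = np.array([
    [0, 50, 50, 2, 50, 1, 50, 99, 50, 50],
    [50, 2, 1, 50, 50, 50, 98, 50, 50, 50],
    [0, 50, 50, 98, 50, 1, 50, 50, 50, 50],
    [0, 50, 50, 50, 50, 99, 50, 50, 2, 50],
    [2, 50, 50, 0, 98, 1, 50, 50, 50, 50],
])
print(indices_loop(data, 2))
```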

Is there a way to do this with numpy?

CodePudding user response:

I don't think you're getting away without a for-loop here. But here's how you could go about it.

For each number from 0 to n, find all of the locations where it appears. For example, for 1:

import numpy as np

locations = np.argwhere(data == 1)
print(locations)
[[0 5]
 [1 2]
 [2 5]
 [4 5]]

You can then turn these (row, column) pairs into a row-to-column map for each number from 0 to n:

from pprint import pprint

known = {
    i: dict(np.argwhere(data == i))
    for i in range(n + 1)
}
pprint(known)
{0: {0: 0, 2: 0, 3: 0, 4: 3},
 1: {0: 5, 1: 2, 2: 5, 4: 5},
 2: {0: 3, 1: 1, 3: 8, 4: 0}}

Do the same for the unknown numbers:

unknown = dict(np.argwhere((data == 98) | (data == 99)))
pprint(unknown)
{0: 7, 1: 6, 2: 3, 3: 5, 4: 4}

Finally, for each cell of the result, look up the column in the known map and fall back to the unknown map when the number is missing from that row:

result = np.array(
    [
        [known[i].get(j, unknown.get(j)) for j in range(len(data))]
        for i in range(n + 1)
    ]
)
print(result)
[[0 6 0 0 3]
 [5 2 5 5 5]
 [3 1 3 8 0]]
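Collected into one self-contained function (the name `indices_via_dicts` is hypothetical), the steps above look roughly like this. The `.tolist()` calls are an addition that turns the NumPy index pairs into plain Python ints before the dicts are built:

```python
import numpy as np


def indices_via_dicts(data, n):
    # known[i] maps row -> column of value i (rows where i is missing are absent).
    known = {i: dict(np.argwhere(data == i).tolist()) for i in range(n + 1)}
    # unknown maps row -> column of the 98/99 placeholder.
    unknown = dict(np.argwhere((data == 98) | (data == 99)).tolist())
    # Use the known column when present, else fall back to the placeholder column.
    return np.array([
        [known[i].get(j, unknown.get(j)) for j in range(len(data))]
        for i in range(n + 1)
    ])


data = np.array([
    [0, 50, 50, 2, 50, 1, 50, 99, 50, 50],
    [50, 2, 1, 50, 50, 50, 98, 50, 50, 50],
    [0, 50, 50, 98, 50, 1, 50, 50, 50, 50],
    [0, 50, 50, 50, 50, 99, 50, 50, 2, 50],
    [2, 50, 50, 0, 98, 1, 50, 50, 50, 50],
])
print(indices_via_dicts(data, 2))
```

The loop over rows still happens in Python here; only the searches themselves are vectorized.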

Bonus: getting fancy with the dictionary constructor and unpacking:

from collections import OrderedDict

unknown = np.argwhere((data == 98) | (data == 99))
results = np.array([
    [*OrderedDict((*unknown, *np.argwhere(data == i))).values()]
    for i in range(n + 1)
])
print(results)
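Since Python 3.7, a plain `dict` also preserves insertion order, and re-assigning an existing key keeps its original position, so the bonus version can be sketched without `OrderedDict`. Like the original, it relies on each row key first appearing in row order, which holds here because every example row contains a 98 or 99:

```python
import numpy as np

data = np.array([
    [0, 50, 50, 2, 50, 1, 50, 99, 50, 50],
    [50, 2, 1, 50, 50, 50, 98, 50, 50, 50],
    [0, 50, 50, 98, 50, 1, 50, 50, 50, 50],
    [0, 50, 50, 50, 50, 99, 50, 50, 2, 50],
    [2, 50, 50, 0, 98, 1, 50, 50, 50, 50],
])
n = 2

# (row, col) pairs of the 98/99 placeholders, one per row in this example.
unknown = np.argwhere((data == 98) | (data == 99)).tolist()
results = np.array([
    # Placeholder pairs go in first; the known pairs for value i then
    # overwrite them, while each row key keeps its first (sorted) position.
    list(dict(unknown + np.argwhere(data == i).tolist()).values())
    for i in range(n + 1)
])
print(results)
```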

CodePudding user response:

The following NumPy solution rather aggressively relies on the assumptions listed in the question. If they are not 100% guaranteed, some more checks may be in order.

The mildly clever bit (even if I say so myself) here is to use the data array itself for finding the right destinations of their indices. For example, all the 2's need their indices stored in row 2 of the output array. Using this we can bulk store most of the indices in a single operation.
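The destination trick can be seen in isolation on a toy array (the values and the `cols` array below are made up for illustration). Indexing `data` with the row indexer and a set of columns yields the values stored there, and those values are then used as the destination columns for the column indices themselves:

```python
import numpy as np

# Two toy rows containing the values 0..2 (50 is filler).
data = np.array([
    [2, 50, 0, 1],
    [1, 0, 50, 2],
])
cols = np.array([
    [0, 2, 3],  # columns of the small values in row 0
    [0, 1, 3],  # columns of the small values in row 1
])
yr = np.arange(data.shape[0])[:, None]  # row indexer, shape (2, 1)
out = np.empty((2, 3), dtype=int)
# data[yr, cols] are the values themselves; use them as destination columns.
out[yr, data[yr, cols]] = cols
print(out)  # [[2 3 0]
            #  [1 0 3]]
```

Each output row now reads: column of the 0, column of the 1, column of the 2.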

Example input is in array data:

import numpy as np

n = 2
y,x = data.shape
out = np.empty((y,n+1),int)
# find 98 falling back to 99 if necessary
# and fill output array with their indices
# if neither exists some nonsense will be written but that does no harm
# most of this will be overwritten later
out.T[...] = ((data-98)&127).argmin(axis=1)
# find n+1 lowest values in each row
idx = data.argpartition(n,axis=1)[:,:n+1]
# construct auxiliary indexer
yr = np.arange(y)[:,None]
# put indices of low values where they belong
out[yr,data[yr,idx[:,:-1]]] = idx[:,:-1]
#      ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ 
#         the clever bit
# rows with no missing number still need the last value
nomiss, = (data[range(y),idx[:,n]] == n).nonzero()
out[nomiss,n] = idx[nomiss,n]
# admire
print(out.T)

outputs:

[[0 6 0 0 3]
 [5 2 5 5 5]
 [3 1 3 8 0]]
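The two helper expressions can be checked on single rows from the example. The sketch below assumes, as the question states, that apart from the 98/99 placeholders every value lies in 0..50:

```python
import numpy as np

row = np.array([0, 50, 50, 2, 50, 1, 50, 99, 50, 50])
n = 2

# Placeholder key: 98 -> 0, 99 -> 1, any value v in 0..50 -> v + 30,
# so argmin finds a 98 first, otherwise a 99.
key = (row - 98) & 127
print(key.tolist())  # [30, 80, 80, 32, 80, 31, 80, 1, 80, 80]
print(key.argmin())  # 7, the column of the 99

# argpartition with kth=n moves the n+1 smallest values to the front and
# puts the (n+1)-th smallest value at position n.
idx = row.argpartition(n)[:n + 1]
print(row[idx[n]])  # 2 == n, so no number is missing from this row

missing_row = np.array([50, 2, 1, 50, 50, 50, 98, 50, 50, 50])
idx2 = missing_row.argpartition(n)[:n + 1]
print(missing_row[idx2[n]])  # 50, so one number (here 0) is missing
```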