How to reverse index a 2-D array

Time:01-13

I have a 2-D MxN array A. Each row is a sequence of column indices, padded with -1 at the end, e.g.:

[[ 2 1 -1 -1 -1]
 [ 1 4  3 -1 -1]
 [ 3 1  0 -1 -1]]

I have another MxN array of float values B:

[[ 0.7 0.4 1.5 2.0 4.4 ]
 [ 0.8 4.0  0.3 0.11 0.53]
 [ 0.6 7.4  0.22 0.71 0.06]]

I want to use the indices in A to filter B. In each row, only the positions listed in the corresponding row of A keep their values, and every other entry is set to 0.0. The result would look like:

[[ 0.0 0.4 1.5 0.0 0.0 ]
 [ 0.0 4.0  0.0 0.11 0.53 ]
 [ 0.6 7.4  0.0 0.71 0.0]]

What's a good way to do this in "pure" NumPy? (I'd like to do this in pure NumPy so that I can JIT it with JAX.)
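To pin down the intended semantics, here is a minimal reference sketch: a plain Python loop, so neither vectorized nor jit-friendly. The function name `filter_by_indices_loop` is hypothetical, and it assumes that any negative entry in A is padding.

```python
import numpy as np

def filter_by_indices_loop(A, B):
    """Keep B[i, j] only if j appears in row i of A; negative entries are padding."""
    out = np.zeros_like(B)
    for i, row in enumerate(A):
        for j in row:
            if j >= 0:  # skip the -1 padding instead of letting it wrap to the last column
                out[i, j] = B[i, j]
    return out

A = np.array([[2, 1, -1, -1, -1],
              [1, 4, 3, -1, -1],
              [3, 1, 0, -1, -1]])
B = np.array([[0.7, 0.4, 1.5, 2.0, 4.4],
              [0.8, 4.0, 0.3, 0.11, 0.53],
              [0.6, 7.4, 0.22, 0.71, 0.06]])

print(filter_by_indices_loop(A, B))
```

Any vectorized version can be checked against this loop on small inputs.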

Answer:

You can use broadcasting, but note that (in pure NumPy, at least) it creates a large boolean intermediate array of shape (M, N, N):

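import numpy as np

A = ...  # the (M, N) index array from the question
B = ...  # the (M, N) float array from the question

M, N = A.shape

# A[..., None] == np.arange(N) has shape (M, N, N): entry [i, j, k] is A[i, j] == k.
# any(axis=1) asks whether column k appears anywhere in row i; -1 never matches.
out = np.where(np.any(A[..., None] == np.arange(N), axis=1), B, 0.0)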

out:

array([[0.  , 0.4 , 1.5 , 0.  , 0.  ],
       [0.  , 4.  , 0.  , 0.11, 0.53],
       [0.6 , 7.4 , 0.  , 0.71, 0.  ]])
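To show where the (M, N, N) intermediate comes from, this sketch evaluates the mask from the answer above in two steps on the example A. Its memory use grows with M·N², so it is best suited to narrow rows.

```python
import numpy as np

A = np.array([[2, 1, -1, -1, -1],
              [1, 4, 3, -1, -1],
              [3, 1, 0, -1, -1]])
M, N = A.shape

# (M, N, 1) compared with (N,) broadcasts to (M, N, N): hits[i, j, k] is A[i, j] == k
hits = A[..., None] == np.arange(N)
print(hits.shape)  # (3, 5, 5)

# Collapsing axis 1 asks "does column k appear anywhere in row i?";
# -1 never equals any value in np.arange(N), so the padding is ignored automatically.
keep = hits.any(axis=1)
print(keep.astype(int))
```

Because the padding value falls outside `0..N-1`, this approach needs no special handling for -1.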

Answer:

NumPy supports fancy (integer array) indexing. Ignoring the -1 entries for the moment, you can do something like this:

# Row indices of shape (M, 1) broadcast against A of shape (M, N)
index = (np.arange(B.shape[0]).reshape(-1, 1), A)
result = np.zeros_like(B)
result[index] = B[index]

This works because the index arrays are broadcast together. The column vector np.arange(B.shape[0]).reshape(-1, 1) has shape (M, 1), so it pairs every element in a given row of A with the corresponding row of B and result.
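A small sketch on the question's data makes both points concrete: the two index arrays broadcast to A's shape, and, without a correction, each -1 wraps around to the last column.

```python
import numpy as np

A = np.array([[2, 1, -1, -1, -1],
              [1, 4, 3, -1, -1],
              [3, 1, 0, -1, -1]])
B = np.array([[0.7, 0.4, 1.5, 2.0, 4.4],
              [0.8, 4.0, 0.3, 0.11, 0.53],
              [0.6, 7.4, 0.22, 0.71, 0.06]])

rows = np.arange(B.shape[0]).reshape(-1, 1)       # shape (3, 1)
print(np.broadcast_shapes(rows.shape, A.shape))   # (3, 5): one (row, column) pair per entry of A

naive = np.zeros_like(B)
naive[rows, A] = B[rows, A]
# -1 wraps around to the last column, so every row that has padding copies B[i, -1]:
print(naive[:, -1])  # values 4.4, 0.53, 0.06, although rows 0 and 2 never list column 4
```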

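This example does not handle the fact that -1 is a valid NumPy index: it refers to the last column, so every row of A that contains padding also copies its last value from B into result. You need to clear that element again whenever 4 (the last column index) does not otherwise appear in the row: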

# Rows that contain -1 padding but never explicitly index the last column
mask = (A == -1).any(axis=1) & (A != A.shape[1] - 1).all(axis=1)
result[mask, -1] = 0.0

Here the mask is [True, False, True]: the second row contains a -1, but it also contains a 4, so its last value must be kept.
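Putting the two steps of this answer together, a minimal sketch as one function might look like this. The name `filter_fancy` is hypothetical. Note that it mutates `result` in place, which plain NumPy allows but JAX arrays do not.

```python
import numpy as np

def filter_fancy(A, B):
    """Fancy-indexing version with the -1 correction described above."""
    rows = np.arange(B.shape[0]).reshape(-1, 1)
    result = np.zeros_like(B)
    result[rows, A] = B[rows, A]  # -1 entries also copy the last column
    # Rows that contain padding but never name the last column explicitly
    mask = (A == -1).any(axis=1) & (A != A.shape[1] - 1).all(axis=1)
    result[mask, -1] = 0.0
    return result

A = np.array([[2, 1, -1, -1, -1],
              [1, 4, 3, -1, -1],
              [3, 1, 0, -1, -1]])
B = np.array([[0.7, 0.4, 1.5, 2.0, 4.4],
              [0.8, 4.0, 0.3, 0.11, 0.53],
              [0.6, 7.4, 0.22, 0.71, 0.06]])

print(filter_fancy(A, B))
```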

This approach is fairly efficient. Apart from the gathered values B[index], it creates no more than a couple of boolean arrays the same shape as A for the mask, rather than an (M, N, N) intermediate.
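One variation, not taken from the answer above, avoids the last-column special case entirely: route every -1 to a throwaway extra column N, mark the listed columns in an (M, N + 1) boolean mask, and then drop the extra column. The function name is hypothetical. In JAX, the in-place assignment would presumably have to be written with the functional `keep.at[...].set(True)` update instead; that translation is an untested assumption.

```python
import numpy as np

def filter_dummy_column(A, B):
    """Mask-based variant: send -1 padding to a throwaway extra column N."""
    M, N = B.shape
    idx = np.where(A < 0, N, A)                  # padding -> column N, which B does not have
    keep = np.zeros((M, N + 1), dtype=bool)
    keep[np.arange(M)[:, None], idx] = True      # mark every listed column (and the dummy)
    return np.where(keep[:, :N], B, 0.0)         # drop the dummy column, then filter

A = np.array([[2, 1, -1, -1, -1],
              [1, 4, 3, -1, -1],
              [3, 1, 0, -1, -1]])
B = np.array([[0.7, 0.4, 1.5, 2.0, 4.4],
              [0.8, 4.0, 0.3, 0.11, 0.53],
              [0.6, 7.4, 0.22, 0.71, 0.06]])

print(filter_dummy_column(A, B))
```

Because the padding never touches a real column, no correction step is needed, and a row made entirely of -1 simply yields zeros.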

Answer:

Another possible solution:

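maxr = np.max(A, axis=1)
# Replace each -1 with the row's largest index, which is already present in that row
idx = np.where(A == -1, maxr[:, None], A)
mask = np.zeros(np.shape(B), dtype=bool)
# Set mask[i, idx[i, j]] = True for every (i, j); modifies mask in place
np.put_along_axis(mask, idx, True, axis=1)
out = np.where(mask, B, 0)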

Output:

array([[0.  , 0.4 , 1.5 , 0.  , 0.  ],
       [0.  , 4.  , 0.  , 0.11, 0.53],
       [0.6 , 7.4 , 0.  , 0.71, 0.  ]])
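Wrapped as a function, this approach leaves A untouched. There is one corner case. If a row could consist entirely of -1 (an assumption, since the question's example has no such row), its maximum is also -1, and put_along_axis would then mark the last column. The `guard_empty_rows` flag below is a hypothetical addition that clears such rows.

```python
import numpy as np

def filter_put_along_axis(A, B, guard_empty_rows=True):
    """put_along_axis approach, without overwriting A.

    guard_empty_rows is a hypothetical addition: a row made only of -1 has
    row maximum -1, which would otherwise mark the last column.
    """
    row_max = A.max(axis=1, keepdims=True)
    idx = np.where(A == -1, row_max, A)         # reuse an index already present in the row
    mask = np.zeros(B.shape, dtype=bool)
    np.put_along_axis(mask, idx, True, axis=1)  # mask[i, idx[i, j]] = True, in place
    if guard_empty_rows:
        # Keep a row's marks only if it lists at least one real (non-negative) index
        mask &= (A >= 0).any(axis=1, keepdims=True)
    return np.where(mask, B, 0.0)

A = np.array([[2, 1, -1, -1, -1],
              [1, 4, 3, -1, -1],
              [3, 1, 0, -1, -1]])
B = np.array([[0.7, 0.4, 1.5, 2.0, 4.4],
              [0.8, 4.0, 0.3, 0.11, 0.53],
              [0.6, 7.4, 0.22, 0.71, 0.06]])

print(filter_put_along_axis(A, B))
```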