Get indices of last N elements in each column of array A, but only those that are False in some mask

Time:12-09

Let A be the following array:

import numpy as np

A = np.array([[2, 1, 2, 2],
              [1, 4, 0, 3],
              [0, 0, 3, 4],
              [3, 3, 1, 0],
              [4, 2, 4, 1]])

and let M be the following boolean mask:

M = np.array([[ True, False, False, False],
              [ True, False, False, False],
              [False,  True, False,  True],
              [ True,  True, False, False],
              [False,  True,  True, False]])

Question 1:

How can I get the indices of the last N elements in each column of A, counting only the elements that are False in the mask M?

For example, for N = 2, I would like to get the indices

row_ixs = [2, 4, 0, 1, 2, 3, 3, 4]
col_ixs = [0, 0, 1, 1, 2, 2, 3, 3]
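To pin down exactly what question 1 asks for, here is a slow plain-loop sketch that walks each column and keeps the last N rows where the mask is False. It only needs the mask, not A itself. The function name last_n_unmasked_loop is hypothetical and is meant as a reference to check faster versions against, not as a recommended solution.

```python
import numpy as np

M = np.array([[ True, False, False, False],
              [ True, False, False, False],
              [False,  True, False,  True],
              [ True,  True, False, False],
              [False,  True,  True, False]])

def last_n_unmasked_loop(mask, n):
    """Hypothetical reference version: per column, keep the last n rows where mask is False."""
    row_ixs, col_ixs = [], []
    for j in range(mask.shape[1]):
        # unmasked rows of column j, top to bottom
        rows = [i for i in range(mask.shape[0]) if not mask[i, j]]
        # rows[-0:] would return every row, so treat n == 0 separately
        kept = rows[-n:] if n > 0 else []
        row_ixs.extend(kept)
        col_ixs.extend([j] * len(kept))
    return row_ixs, col_ixs

row_ixs, col_ixs = last_n_unmasked_loop(M, 2)
print(row_ixs)  # [2, 4, 0, 1, 2, 3, 3, 4]
print(col_ixs)  # [0, 0, 1, 1, 2, 2, 3, 3]
```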

Question 2:

How can I get the indices of the elements in each column of A that are False in the mask M, excluding the first N such elements in that column?

For example, for N = 2, I would like to get the indices

row_ixs = [2, 3, 3, 4]
col_ixs = [2, 2, 3, 3]
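The same kind of plain-loop sketch makes question 2 concrete: for each column, list the unmasked rows from top to bottom and drop the first N of them. Again, after_first_n_unmasked_loop is a hypothetical name for a slow reference version.

```python
import numpy as np

M = np.array([[ True, False, False, False],
              [ True, False, False, False],
              [False,  True, False,  True],
              [ True,  True, False, False],
              [False,  True,  True, False]])

def after_first_n_unmasked_loop(mask, n):
    """Hypothetical reference version: per column, skip the first n unmasked rows, keep the rest."""
    row_ixs, col_ixs = [], []
    for j in range(mask.shape[1]):
        # unmasked rows of column j, top to bottom
        rows = [i for i in range(mask.shape[0]) if not mask[i, j]]
        kept = rows[n:]  # everything after the first n unmasked rows
        row_ixs.extend(kept)
        col_ixs.extend([j] * len(kept))
    return row_ixs, col_ixs

row_ixs, col_ixs = after_first_n_unmasked_loop(M, 2)
print(row_ixs)  # [2, 3, 3, 4]
print(col_ixs)  # [2, 2, 3, 3]
```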

Answer:

For question 1, you can use:

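N = 2
M2 = ~M.T                                      # True where M is False; one row per column of A
M3 = M2[:, ::-1].cumsum(axis=1)[:, ::-1] <= N  # at most N unmasked entries from here to the end of the row

col_ixs, row_ixs = np.where(M2 & M3)           # M2 is transposed, so the first output indexes the columns of A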

Output:

# col_ixs
array([0, 0, 1, 1, 2, 2, 3, 3])

# row_ixs
array([2, 4, 0, 1, 2, 3, 3, 4])
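The same steps can be wrapped in a small function that takes N as a parameter instead of hard-coding 2. This is a sketch built from the answer's approach; the name last_n_unmasked is hypothetical, and it returns (row_ixs, col_ixs) in the order the question uses.

```python
import numpy as np

def last_n_unmasked(mask, n):
    """Indices of the last n entries per column where mask is False, grouped by column."""
    valid = ~mask.T                                      # one row per column of A; True = unmasked
    from_end = valid[:, ::-1].cumsum(axis=1)[:, ::-1]    # unmasked count from each position to the row's end
    col_ixs, row_ixs = np.where(valid & (from_end <= n)) # unmasked entries that are among the last n
    return row_ixs, col_ixs

M = np.array([[ True, False, False, False],
              [ True, False, False, False],
              [False,  True, False,  True],
              [ True,  True, False, False],
              [False,  True,  True, False]])

row_ixs, col_ixs = last_n_unmasked(M, 2)
print(row_ixs.tolist())  # [2, 4, 0, 1, 2, 3, 3, 4]
print(col_ixs.tolist())  # [0, 0, 1, 1, 2, 2, 3, 3]
print(last_n_unmasked(M, 1)[0].tolist())  # [4, 1, 3, 4]
```

With n = 0 the result is empty, because every unmasked position has a count of at least 1 from itself.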

Step by step:

# invert boolean values and transpose
M2 = ~M.T

array([[False, False,  True, False,  True],
       [ True,  True, False, False, False],
       [ True,  True,  True,  True, False],
       [ True,  True, False,  True,  True]])

# reversed cumulative sum: number of True values from each position to the end of its row
M2[:, ::-1].cumsum(axis=1)[:, ::-1]

array([[2, 2, 2, 1, 1],
       [2, 1, 0, 0, 0],
       [4, 3, 2, 1, 0],
       [4, 3, 2, 2, 1]])
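The flip-cumsum-flip trick counts True values from each position to the end of the row. The same counts can also be obtained from an ordinary forward cumsum, since the count from position i to the end equals the row total minus the count strictly before i. This is an alternative formulation offered for illustration, not part of the original answer.

```python
import numpy as np

# M2 = ~M.T from the answer
M2 = np.array([[False, False,  True, False,  True],
               [ True,  True, False, False, False],
               [ True,  True,  True,  True, False],
               [ True,  True, False,  True,  True]])

# Reversed cumsum, as in the answer: flip, accumulate, flip back
flipped = M2[:, ::-1].cumsum(axis=1)[:, ::-1]

# Same counts without flipping:
#   count from i to end = total - (count before i) = total - (cumsum[i] - M2[i])
prefix = M2.cumsum(axis=1)
total = prefix[:, -1:]         # keep a 2-D column so it broadcasts across each row
from_end = total - prefix + M2

print(np.array_equal(flipped, from_end))  # True
```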

# threshold at N = 2 and combine with the original M2 to keep only the last 2 True values in each row
M2 & (M2[:, ::-1].cumsum(axis=1)[:, ::-1] <= 2)

array([[False, False,  True, False,  True],
       [ True,  True, False, False, False],
       [False, False,  True,  True, False],
       [False, False, False,  True,  True]])
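The transpose matters mainly for ordering: np.where returns indices in row-major (C) order, so working on ~M.T groups the results by column of A, as the question expects. As a sketch for comparison, the same mask can be computed directly along axis 0 without transposing; the indices then come out row by row and need a stable sort by column to match.

```python
import numpy as np

M = np.array([[ True, False, False, False],
              [ True, False, False, False],
              [False,  True, False,  True],
              [ True,  True, False, False],
              [False,  True,  True, False]])

N = 2
valid = ~M                                        # work on M directly, counting down each column
from_end = valid[::-1].cumsum(axis=0)[::-1]       # unmasked count from each row to the bottom
row_ixs, col_ixs = np.where(valid & (from_end <= N))

print(row_ixs.tolist())  # [0, 1, 2, 2, 3, 3, 4, 4]  (row-major order)
print(col_ixs.tolist())  # [1, 1, 0, 2, 2, 3, 0, 3]

order = np.argsort(col_ixs, kind="stable")        # regroup by column, keeping row order within each column
print(row_ixs[order].tolist())  # [2, 4, 0, 1, 2, 3, 3, 4]
print(col_ixs[order].tolist())  # [0, 0, 1, 1, 2, 2, 3, 3]
```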

Question 2:

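N = 2
M2 = ~M.T                    # True where M is False; one row per column of A
M3 = M2.cumsum(axis=1) > N   # more than N unmasked entries seen so far, i.e. past the first N

col_ixs, row_ixs = np.where(M2 & M3)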

Output:

# col_ixs
array([2, 2, 3, 3])

# row_ixs
array([2, 3, 3, 4])
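As with question 1, the question 2 recipe can be parameterized by N. This is a sketch following the answer's approach; after_first_n_unmasked is a hypothetical name, and it returns (row_ixs, col_ixs) in the question's order.

```python
import numpy as np

def after_first_n_unmasked(mask, n):
    """Indices of unmasked entries per column, skipping the first n of them, grouped by column."""
    valid = ~mask.T                                  # one row per column of A; True = unmasked
    seen = valid.cumsum(axis=1)                      # unmasked count from the start up to and including each position
    col_ixs, row_ixs = np.where(valid & (seen > n))  # unmasked entries past the first n
    return row_ixs, col_ixs

M = np.array([[ True, False, False, False],
              [ True, False, False, False],
              [False,  True, False,  True],
              [ True,  True, False, False],
              [False,  True,  True, False]])

row_ixs, col_ixs = after_first_n_unmasked(M, 2)
print(row_ixs.tolist())  # [2, 3, 3, 4]
print(col_ixs.tolist())  # [2, 2, 3, 3]
print(after_first_n_unmasked(M, 3)[0].tolist())  # [3, 4]
```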