Applying a condition to rows of a 2d array

Time:01-26

Consider a = np.array([0, 0, 1, 1, 2, 2, 3, 3])

In the multiset a, there are exactly 2 instances each of 0, 1, 2, and 3.

I want to find all permutations of a that satisfy the following condition, reading each row from left to right:

Condition: the 1st instances of 0, 1, 2, and 3 must appear in that order, though they need not be adjacent.

[0, 1, _, 2, _, 3, _, _] is ok, [0, 1, _, 3, _, 2, _, _] is not ok (where _ stands for any remaining value).

The 2nd instance of each number may appear anywhere in the row as long as it is after (to the right of) the 1st instance.

[0, 1, 0, 2, 2, 3, 1, 3] is ok
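A minimal sketch of the condition as a plain Python function, assuming each row contains exactly two copies of each value 0..n-1. The name satisfies_condition is hypothetical. It scans the row left to right and only checks first occurrences: the second copy of a value can only ever appear after the first, so the second-instance rule needs no separate test.

```python
def satisfies_condition(row):
    """Return True if the first occurrences of 0, 1, 2, ... appear in that order.

    Hypothetical helper; assumes row holds two copies of each value 0..n-1.
    """
    seen = set()
    next_new = 0  # the value whose first occurrence we expect next
    for value in row:
        if value not in seen:
            # A value seen for the first time must be the next one in sequence.
            if value != next_new:
                return False
            seen.add(value)
            next_new += 1
    return True


print(satisfies_condition([0, 1, 0, 2, 2, 3, 1, 3]))  # True
print(satisfies_condition([0, 1, 0, 3, 2, 2, 1, 3]))  # False: 3 appears before 2
```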

I've started by finding all 8!/2**4 = 2520 permutations of the multiset a:

```python
from sympy.utilities.iterables import multiset_permutations
import numpy as np

a = np.array([0, 0, 1, 1, 2, 2, 3, 3])

# multiset_permutations yields each distinct ordering exactly once
out = np.array(list(multiset_permutations(a)))
```
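For reference, out has (2n)!/2^n rows for n pairs, which shows how quickly it grows once a has 5 pairs. A small sketch with hypothetical helper names pair_array and n_multiset_permutations that builds a for any number of pairs and computes that count:

```python
from math import factorial

import numpy as np


def pair_array(n_pairs):
    """Build [0, 0, 1, 1, ..., n-1, n-1] for n_pairs values (hypothetical helper)."""
    return np.repeat(np.arange(n_pairs), 2)


def n_multiset_permutations(n_pairs):
    """Number of distinct orderings of pair_array(n_pairs): (2n)! / 2**n."""
    return factorial(2 * n_pairs) // 2 ** n_pairs


print(pair_array(4))               # [0 0 1 1 2 2 3 3]
print(n_multiset_permutations(4))  # 2520
print(n_multiset_permutations(5))  # 113400
```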

My difficulty is that I get lost in the details when I try to express the condition in code. To compound the problem, the actual array a could have up to 5 pairs of values.

QUESTION: How can the condition be written so that I can eliminate from array out all permutation rows that do not satisfy it?

Answer:

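Since your array consists of exactly two copies of each element of np.arange(4), you can use np.argmax to find the index of the first occurrence of each value and check that those indices are strictly increasing. The second-instance rule holds automatically, because a value's second occurrence always comes after its first, so only the first occurrences need checking: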

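```python
from sympy.utilities.iterables import multiset_permutations
import numpy as np

a = np.array([0, 0, 1, 1, 2, 2, 3, 3])

max_value = np.max(a)
uniques = np.arange(max_value + 1)
# or you can just do
# uniques = np.unique(a)

resultList = []
for p in multiset_permutations(a):
    # idx[k] is the index of the first occurrence of uniques[k] in p
    idx = np.argmax(np.asarray(p) == uniques[:, None], axis=1)
    # keep p only if those first occurrences are strictly increasing
    if (idx[:-1] < idx[1:]).all():
        resultList.append(p)
```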
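The same argmax test can be applied to the whole out array at once with broadcasting instead of a Python loop. This is a sketch with a hypothetical function filter_rows; to keep it self-contained, it builds out from deduplicated itertools.permutations instead of sympy, which is only practical for small inputs such as 4 pairs.

```python
from itertools import permutations

import numpy as np


def filter_rows(out):
    """Keep the rows of out whose first occurrences of each value appear in order.

    Hypothetical helper; out is a 2-D integer array with one permutation per row.
    """
    values = np.unique(out)
    # matches[r, v, j] is True where row r holds values[v] at column j.
    matches = out[:, None, :] == values[None, :, None]
    # argmax returns the first True along the column axis, i.e. the first
    # position of each value in each row (every value occurs in every row here).
    first = matches.argmax(axis=2)
    keep = np.all(np.diff(first, axis=1) > 0, axis=1)
    return out[keep]


# Stand-in for the sympy output: distinct permutations via a set.
a = (0, 0, 1, 1, 2, 2, 3, 3)
out = np.array(sorted(set(permutations(a))))
valid = filter_rows(out)
print(out.shape, valid.shape)  # (2520, 8) (105, 8)
```

The intermediate boolean array has shape (rows, values, columns), so for 5 pairs it holds about 113400 × 5 × 10 entries, which is still modest.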

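The loop-based version leaves 105 permutations in resultList for 4 pairs and 945 for 5 pairs. In general there are (2n)!/(2^n n!) valid rows for n pairs: each way of splitting the 2n positions into n pairs gives exactly one valid row, since the pair containing the leftmost position must hold 0, the pair containing the leftmost remaining position must hold 1, and so on.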
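A different technique from the answer's filter is to generate only the valid rows directly. Because the first occurrences must be in order, the leftmost empty slot must always receive the first copy of the next new value, and its second copy can go in any empty slot further right. The recursive generator valid_rows is a hypothetical sketch of that idea; it yields exactly one row per way of pairing up the positions, so it never builds the rows that would be rejected.

```python
def valid_rows(n_pairs):
    """Yield every row of 2 * n_pairs values that satisfies the condition.

    Hypothetical helper: the leftmost empty slot gets the first copy of the
    next new value; its second copy may go in any empty slot to the right.
    """
    row = [None] * (2 * n_pairs)

    def place(value):
        if value == n_pairs:
            yield list(row)  # row is full: emit a copy
            return
        first = row.index(None)  # leftmost empty slot
        row[first] = value
        for second in range(first + 1, len(row)):
            if row[second] is None:
                row[second] = value
                yield from place(value + 1)
                row[second] = None  # undo before trying the next slot
        row[first] = None

    yield from place(0)


rows = list(valid_rows(4))
print(len(rows))                    # 105
print(rows[0])                      # [0, 0, 1, 1, 2, 2, 3, 3]
print(len(list(valid_rows(5))))     # 945
```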
