Check if 2D sub-array is ordered - Python JAX

Time:05-26

Let us suppose that we have an array ordered. We want to check whether the sub-arrays t and t_inv follow the same order as the one imposed by the ordered array.

Reading from left to right, the first element is [0,0], and so on up to [0,3]. t_inv is inverted because its first two elements are swapped, so it does not follow the ordering in ordered.

import jax.numpy as jnp

# imposed order
ordered = jnp.array([[0, 0],[0,1],[0,2],[0,3]])

# sub-arrays to check: t keeps the imposed order, t_inv swaps its first two rows
t = jnp.array([[0, 0],[0, 1], [0,3]])
t_inv = jnp.array([[0, 1],[0, 0], [0,3]])

I expect the following result:

ordered(t) = 1, because t is "ordered"
ordered(t_inv) = -1, because its first two rows are "swapped/not ordered"

How can I check that the sub-arrays are indeed part of the ordered array, and output whether their order is correct or not?
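To pin down the expected behaviour before reaching for JAX, here is a small plain-Python reference. It is a sketch under stated assumptions: the helper name is_ordered_reference is hypothetical, the inputs are nested Python lists (call .tolist() on a JAX array first), and a row that does not appear in ordered is assumed to count as "not ordered" (-1).

```python
def is_ordered_reference(t, ordered):
    """Hypothetical reference helper: return 1 if every row of t appears in
    `ordered` and t's rows keep the relative order they have in `ordered`,
    otherwise return -1."""
    order = [tuple(row) for row in ordered]
    positions = []
    for row in t:
        row = tuple(row)
        if row not in order:
            # Assumption: a row missing from `ordered` counts as "not ordered".
            return -1
        positions.append(order.index(row))
    # Non-decreasing positions mean t follows the imposed order.
    return 1 if all(a <= b for a, b in zip(positions, positions[1:])) else -1


ordered = [[0, 0], [0, 1], [0, 2], [0, 3]]
t = [[0, 0], [0, 1], [0, 3]]
t_inv = [[0, 1], [0, 0], [0, 3]]

print(is_ordered_reference(t, ordered))      # 1
print(is_ordered_reference(t_inv, ordered))  # -1
```

Any vectorized JAX version should agree with this helper on the same inputs, which makes it handy for spot checks.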

Answer:

You could do something like this:

import jax.numpy as jnp

# imposed order 
ordered = jnp.array([[0, 0],[0,1],[0,2],[0,3]])

# array with permuted order
t = jnp.array([[0, 0],[0, 1], [0,3]])
t_inv = jnp.array([[0, 1],[0, 0], [0,3]])


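def is_sorted(t, ordered):
  # (t[:, None] == ordered).all(-1) is a (len(t), len(ordered)) boolean matrix
  # that is True where row i of t equals row j of ordered; the column indices
  # of its True entries are the positions of t's rows within ordered.
  # Note: rows of t that do not occur in ordered are silently skipped here.
  index = jnp.where((t[:, None] == ordered).all(-1))[1]
  # The rows are in order exactly when these positions are already sorted.
  return jnp.where((index == jnp.sort(index)).all(), 1, -1)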

print(is_sorted(t, ordered))
# 1
print(is_sorted(t_inv, ordered))
# -1
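One caveat: the single-argument form of jnp.where returns arrays whose length depends on the data, so is_sorted above cannot be wrapped in jax.jit, and rows missing from ordered are simply ignored. A fixed-shape variant is sketched below. It uses argmax to pick the first matching position for every row and also returns -1 when some row of t is absent from ordered. That membership rule, and the name is_sorted_jit, are assumptions of this sketch rather than part of the original answer.

```python
import jax
import jax.numpy as jnp


@jax.jit
def is_sorted_jit(t, ordered):
    # matches[i, j] is True when row i of t equals row j of ordered.
    matches = (t[:, None, :] == ordered[None, :, :]).all(-1)
    # Assumption: every row of t must occur somewhere in ordered.
    present = matches.any(axis=1)
    # argmax gives the first matching position (0 when there is no match,
    # which is harmless because `present` already rejects that case).
    index = jnp.argmax(matches.astype(jnp.int32), axis=1)
    # Non-decreasing positions mean t follows the imposed order.
    in_order = jnp.all(index[1:] >= index[:-1])
    return jnp.where(present.all() & in_order, 1, -1)


ordered = jnp.array([[0, 0], [0, 1], [0, 2], [0, 3]])
t = jnp.array([[0, 0], [0, 1], [0, 3]])
t_inv = jnp.array([[0, 1], [0, 0], [0, 3]])

print(is_sorted_jit(t, ordered))      # 1
print(is_sorted_jit(t_inv, ordered))  # -1
```

Because every intermediate has a static shape, the function compiles once per input shape.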

Scaling-wise, a solution based on searchsorted might be faster, since the pairwise comparison builds a full len(t) × len(ordered) match matrix. However, the current implementation of jnp.searchsorted in JAX is relatively slow because XLA has no native binary-search primitive, so in practice the full pairwise comparison is often the faster option.
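For comparison, a searchsorted-based version could look roughly like the sketch below. It makes assumptions the original does not: every entry is a non-negative integer smaller than a caller-supplied base, so each 2-element row (a, b) can be encoded as the unique scalar key a * base + b, and the rows of ordered are unique. The function name and the key encoding are illustrative only, and no claim is made that this is faster in practice.

```python
import jax.numpy as jnp


def is_sorted_searchsorted(t, ordered, base):
    # Assumption: entries are non-negative integers < base, so each row
    # (a, b) maps to the unique key a * base + b.
    ordered_keys = ordered[:, 0] * base + ordered[:, 1]
    t_keys = t[:, 0] * base + t[:, 1]
    # searchsorted needs a sorted array; perm maps sorted slots back to
    # positions in the original `ordered`.
    perm = jnp.argsort(ordered_keys)
    sorted_keys = ordered_keys[perm]
    # Binary search for each key of t; clip keeps out-of-range results valid.
    pos = jnp.clip(jnp.searchsorted(sorted_keys, t_keys), 0, len(sorted_keys) - 1)
    # A row is present only if the slot found actually holds its key.
    present = (sorted_keys[pos] == t_keys).all()
    index = perm[pos]  # position of each row of t within `ordered`
    in_order = jnp.all(index[1:] >= index[:-1])
    return jnp.where(present & in_order, 1, -1)


ordered = jnp.array([[0, 0], [0, 1], [0, 2], [0, 3]])
t = jnp.array([[0, 0], [0, 1], [0, 3]])
t_inv = jnp.array([[0, 1], [0, 0], [0, 3]])

print(is_sorted_searchsorted(t, ordered, base=4))      # 1
print(is_sorted_searchsorted(t_inv, ordered, base=4))  # -1
```

This trades the O(len(t) · len(ordered)) match matrix for a sort plus one binary search per row. Whether that pays off depends on array sizes and on how jnp.searchsorted is lowered on your backend, as noted above.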
