Let us suppose that we have an array `ordered`. We want to check whether the sub-arrays `t` and `t_inv` follow the same order as the one imposed by the `ordered` array.
Reading from left to right, the first element is `[0,0]` and so on until `[0,3]`.
`t_inv` is inverted because its first two elements are swapped, so it does not follow the ordering in `ordered`.
```python
import jax.numpy as jnp

# imposed order
ordered = jnp.array([[0, 0],[0,1],[0,2],[0,3]])
# array with permuted order
t = jnp.array([[[0, 0],[0, 1], [0,3]]])
t_inv = jnp.array([[[0, 1],[0, 0], [0,3]]])
```
I expect the following result: `ordered(t) = 1`, because the rows are in order, and `ordered(t_inv) = -1`, because the first two rows are swapped, i.e. not in order.
How can you check that the sub-arrays are indeed part of the `ordered` array and output whether their order is correct or not?
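To pin down the expected semantics, here is a plain-Python reference sketch (not JAX, and not from the original post). It assumes every row of `t` is an exact row of `ordered`, takes nested lists rather than JAX arrays (JAX arrays are not hashable), and the name `order_sign_reference` is hypothetical. "In order" is read as "positions in `ordered` never decrease from left to right".

```python
# Hypothetical plain-Python reference for the 1 / -1 semantics in the question.
# Assumes rows are given as nested lists and every row of `t` appears in `ordered`.
def order_sign_reference(t, ordered):
    # Map each row of `ordered` (as a tuple) to its position in the imposed order.
    position = {tuple(row): i for i, row in enumerate(ordered)}
    # Position of each row of `t`; a KeyError means the row is not in `ordered`.
    idx = [position[tuple(row)] for row in t]
    # 1 if positions never decrease from left to right, otherwise -1.
    return 1 if all(a <= b for a, b in zip(idx, idx[1:])) else -1

ordered = [[0, 0], [0, 1], [0, 2], [0, 3]]
print(order_sign_reference([[0, 0], [0, 1], [0, 3]], ordered))  # 1
print(order_sign_reference([[0, 1], [0, 0], [0, 3]], ordered))  # -1
```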
Answer:
You could do something like this:
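```python
import jax.numpy as jnp

# imposed order
ordered = jnp.array([[0, 0], [0, 1], [0, 2], [0, 3]])
# array with permuted order
t = jnp.array([[0, 0], [0, 1], [0, 3]])
t_inv = jnp.array([[0, 1], [0, 0], [0, 3]])

def is_sorted(t, ordered):
    # Compare every row of t with every row of ordered (shape: len(t) x len(ordered));
    # the column index of each match is that row's position in ordered.
    index = jnp.where((t[:, None] == ordered).all(-1))[1]
    # The order is respected if the positions are already sorted.
    return jnp.where((index == jnp.sort(index)).all(), 1, -1)

print(is_sorted(t, ordered))
# 1
print(is_sorted(t_inv, ordered))
# -1
```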
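Two caveats with `is_sorted` as written: the single-argument `jnp.where` call returns an array whose length depends on the data, so the function cannot be wrapped in `jax.jit`, and rows of `t` that are missing from `ordered` are silently dropped instead of being reported. A minimal sketch of a fixed-shape variant is below; it reuses the same pairwise comparison, takes each row's position with `argmax`, and checks membership explicitly. The function name `order_sign` and the convention of returning `0` when a row is missing are hypothetical, not part of the original answer.

```python
import jax
import jax.numpy as jnp

@jax.jit
def order_sign(t, ordered):
    # matches[i, j] is True when row i of t equals row j of ordered.
    matches = (t[:, None, :] == ordered[None, :, :]).all(-1)
    # Position of each row of t in ordered (argmax gives 0 for rows with no match,
    # so membership is checked separately).
    index = jnp.argmax(matches.astype(jnp.int32), axis=1)
    found = matches.any(axis=1)
    # Non-decreasing positions from left to right means the imposed order is respected.
    in_order = jnp.all(index[1:] >= index[:-1])
    # Hypothetical convention: 0 if some row of t does not appear in ordered.
    return jnp.where(found.all(), jnp.where(in_order, 1, -1), 0)

ordered = jnp.array([[0, 0], [0, 1], [0, 2], [0, 3]])
print(order_sign(jnp.array([[0, 0], [0, 1], [0, 3]]), ordered))  # 1
print(order_sign(jnp.array([[0, 1], [0, 0], [0, 3]]), ordered))  # -1
print(order_sign(jnp.array([[0, 0], [0, 5]]), ordered))          # 0
```

All shapes here depend only on the input shapes, which is what lets `jax.jit` trace the function; the cost is the same O(len(t) × len(ordered)) comparison as the original.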
Scaling-wise, a solution based on `searchsorted` might be faster, but at the time of writing `jnp.searchsorted` in JAX is relatively slow because XLA has no native binary search algorithm, so in practice the full pairwise comparison is often more performant.
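For comparison, here is a rough sketch of what a `searchsorted`-based version could look like. It assumes each row has exactly two non-negative integer columns, that the rows of `ordered` are unique, and that every second-column value is smaller than a caller-supplied `base`, so that each row can be encoded as the scalar key `row[0] * base + row[1]`. That encoding, the `base` parameter, and the name `order_sign_searchsorted` are all hypothetical. Newer JAX releases also accept a `method` argument to `jnp.searchsorted`, which may shift the trade-off described above, so benchmarking on your own array sizes is the safest guide.

```python
import jax.numpy as jnp

def order_sign_searchsorted(t, ordered, base):
    # Hypothetical key encoding: 2-column, non-negative integer rows with every
    # second-column value < base, so each row maps to a unique scalar key.
    ordered_keys = ordered[:, 0] * base + ordered[:, 1]
    t_keys = t[:, 0] * base + t[:, 1]
    # The imposed order need not be sorted by key: sort keys and keep the permutation.
    perm = jnp.argsort(ordered_keys)
    sorted_keys = ordered_keys[perm]
    # Binary-search each key of t; clip so out-of-range results can still be indexed.
    pos = jnp.clip(jnp.searchsorted(sorted_keys, t_keys), 0, len(sorted_keys) - 1)
    found = sorted_keys[pos] == t_keys
    # Map sorted positions back to positions in the imposed order.
    index = perm[pos]
    in_order = jnp.all(index[1:] >= index[:-1])
    # Same hypothetical convention as before: 0 if some row is missing.
    return jnp.where(found.all(), jnp.where(in_order, 1, -1), 0)

ordered = jnp.array([[0, 0], [0, 1], [0, 2], [0, 3]])
print(order_sign_searchsorted(jnp.array([[0, 0], [0, 1], [0, 3]]), ordered, base=10))  # 1
print(order_sign_searchsorted(jnp.array([[0, 1], [0, 0], [0, 3]]), ordered, base=10))  # -1
```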