Joint construction of a random permutation and its inverse using NumPy

Time:08-27

I am looking to construct a random permutation of [1, 2, ..., n] along with its inverse using NumPy. In my application, n can be on the order of 100 million, so I am looking for a solution that constructs both the permutation and its inverse as quickly as possible.
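
For concreteness, the inverse pinv of a permutation p is the array that undoes it: p[pinv[k]] == k and pinv[p[i]] == i for every index. The sketch below checks this on a small, hand-picked permutation; it uses 0-based values as NumPy does, and the array values are arbitrary illustrations.

```python
import numpy as np

p = np.array([2, 0, 4, 1, 3])     # an arbitrary permutation of 0..4
pinv = np.array([1, 3, 0, 4, 2])  # its inverse: pinv[v] is the position of value v in p

idx = np.arange(len(p))
# Composing in either order gives the identity permutation.
assert np.array_equal(p[pinv], idx)
assert np.array_equal(pinv[p], idx)
```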

What I've tried:

  1. Computing a random permutation and its inverse separately using built-in NumPy functions:
import numpy as np
p = np.random.permutation(n)
pinv = np.argsort(p)
  2. The same idea as approach 1, but using the solution provided here, which I found speeds up the computation of pinv by an order of magnitude:
def invert_permutation_numpy2(permutation):
    inv = np.empty_like(permutation)
    inv[permutation] = np.arange(len(inv), dtype=inv.dtype)  # sets inv[p[i]] = i for every i
    return inv
p = np.random.permutation(n)
pinv = invert_permutation_numpy2(p)
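
Both approaches should return the same array. A plausible reason for the speed difference is that np.argsort sorts, which costs O(n log n), whereas the scatter assignment inv[permutation] = ... writes each entry once in O(n). A quick check on a small n follows; it is a sketch in which the seeded np.random.default_rng generator, the size, and the helper name check_inverse are illustrative assumptions, not part of the original question.

```python
import numpy as np

def invert_permutation_numpy2(permutation):
    # Scatter: entry permutation[i] of the inverse receives the value i (O(n)).
    inv = np.empty_like(permutation)
    inv[permutation] = np.arange(len(inv), dtype=inv.dtype)
    return inv

def check_inverse(p, pinv):
    # Hypothetical helper: p and pinv are mutual inverses iff both compositions are the identity.
    idx = np.arange(len(p))
    return np.array_equal(p[pinv], idx) and np.array_equal(pinv[p], idx)

rng = np.random.default_rng(0)  # seeded only for reproducibility
n = 1_000
p = rng.permutation(n)

pinv_sort = np.argsort(p)                    # approach 1: O(n log n) sort
pinv_scatter = invert_permutation_numpy2(p)  # approach 2: O(n) scatter

assert np.array_equal(pinv_sort, pinv_scatter)
assert check_inverse(p, pinv_scatter)
```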

I'm hoping that there is a solution that computes p and pinv jointly and yields additional speedup.

Answer:

The following is a straightforward implementation of the Fisher-Yates shuffle (pseudocode from here) that fills in the inverse as it goes. When compiled with Numba, it is faster than the NumPy approach:

import numpy as np
import numba

@numba.njit
def randperm(n):
  """Permutation of 1, 2, ..., n and its inverse (pinv holds 0-based positions)"""
  p = np.arange(1, n+1, dtype=np.int32) # initialize with identity permutation
  pinv = np.empty_like(p)
  for i in range(n-1, 0, -1):           # loop over all items except the first one
    z = np.random.randint(0, i+1)       # uniform in [0, i]
    temp = p[z]                         # swap numbers at i and z
    p[z] = p[i]
    p[i] = temp
    pinv[temp-1] = i                    # temp is now p[i], which is final: pinv[p[i]-1] = i
  pinv[p[0]-1] = 0  
  return p, pinv
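
The joint trick is that, in Fisher-Yates, position i is final as soon as it has been swapped, so its inverse entry can be written at the same moment. To sanity-check that logic without Numba, here is a sketch of the same loop in 0-based form, matching np.random.permutation(n). The name randperm0, the seed parameter, and the use of np.random.default_rng are assumptions for illustration. In plain Python this loop is far too slow for n = 100 million, so it is only meant for small inputs.

```python
import numpy as np

def randperm0(n, seed=None):
    """Hypothetical 0-based variant: random permutation p of 0..n-1 and its inverse."""
    rng = np.random.default_rng(seed)
    p = np.arange(n)               # start from the identity permutation
    pinv = np.empty_like(p)
    for i in range(n - 1, 0, -1):
        z = rng.integers(0, i + 1)  # uniform in [0, i]; the upper bound is exclusive
        p[z], p[i] = p[i], p[z]     # swap; p[i] is now final
        pinv[p[i]] = i              # record where value p[i] ended up
    if n > 0:
        pinv[p[0]] = 0              # the remaining position 0 is final too
    return p, pinv

p, pinv = randperm0(10, seed=1)
assert np.array_equal(np.sort(p), np.arange(10))
assert np.array_equal(p[pinv], np.arange(10))
```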

Comparison:

%timeit p, pinv = randperm(100_000_000)
# 12.3 s ± 212 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%%timeit
p = np.random.permutation(np.arange(1, 100_000_000+1, dtype=np.int32))
pinv = np.argsort(p)
# 31.8 s ± 439 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
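
A baseline worth adding to this comparison is the scatter-based inverse from approach 2 in the question, since it is already much faster than np.argsort. The sketch below is a plain-Python timing harness; the helper names best_time, perm_then_argsort and perm_then_scatter are hypothetical. It calls each function once before timing so that one-time costs, such as Numba's JIT compilation of randperm if it were added to the tuple, are not counted. Results will vary by machine.

```python
import time
import numpy as np

def best_time(fn, *args, repeat=3):
    # Warm-up call so one-time costs (e.g. JIT compilation) are excluded.
    fn(*args)
    times = []
    for _ in range(repeat):
        t0 = time.perf_counter()
        fn(*args)
        times.append(time.perf_counter() - t0)
    return min(times)

def perm_then_argsort(n):
    p = np.random.permutation(n)
    return p, np.argsort(p)          # O(n log n) inverse

def perm_then_scatter(n):
    p = np.random.permutation(n)
    pinv = np.empty_like(p)
    pinv[p] = np.arange(n, dtype=p.dtype)  # O(n) inverse
    return p, pinv

n = 1_000_000  # much smaller than 100 million, to keep the sketch quick
for fn in (perm_then_argsort, perm_then_scatter):
    print(f"{fn.__name__}: {best_time(fn, n):.4f} s")
```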

(NB: np.random.permutation(n) gives a permutation of 0, 1, ..., n-1 instead of 1, 2, ..., n, which is why the NumPy comparison above permutes an explicit np.arange(1, ...) array.)
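
Because randperm stores 0-based positions in pinv (pinv[p[i]-1] = i), converting between the two conventions is cheap: adding 1 to a 0-based permutation gives the 1-based one, and its 0-based inverse can be reused unchanged. A sketch checking that relationship; the seeded generator and the names p0 and p1 are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(42)
n = 8
p0 = rng.permutation(n)          # permutation of 0..n-1
pinv = np.empty_like(p0)
pinv[p0] = np.arange(n)          # 0-based inverse of p0

p1 = p0 + 1                      # the same permutation written with values 1..n
# In randperm's convention pinv[v - 1] is the position of value v, which is unchanged:
assert np.array_equal(p1[pinv], np.arange(1, n + 1))
```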
