Convert a recursive python code to a non-recursive version

Time:02-18

The code below works until n_symbols, length, and distinct get large. For example, on my computer n_symbols=512, length=512, distinct=300 fails with RecursionError: maximum recursion depth exceeded in comparison, and I get overflow errors if I increase the lru_cache value.
What I want is a non-recursive version of this code.

from functools import lru_cache
@lru_cache
def get_permutations_count(n_symbols, length, distinct, used=0):
    '''
     - n_symbols: number of symbols in the alphabet
     - length: the number of symbols in each sequence
     - distinct: the number of distinct symbols in each sequence
    '''
    if distinct < 0:
        return 0
    if length == 0:
        return 1 if distinct == 0 else 0
    else:
        return \
          get_permutations_count(n_symbols, length-1, distinct-0, used+0) * used + \
          get_permutations_count(n_symbols, length-1, distinct-1, used+1) * (n_symbols - used)

Then

get_permutations_count(n_symbols=300, length=300, distinct=270)

runs in ~0.5 seconds, giving the answer

2729511887951350984580070745513114266766906881300774347439917775
7093985721949669285469996223829969654724957176705978029888262889
8157939885553971500652353177628564896814078569667364402373549268
5524290993833663948683375995196081654415976659499171897405039547
1546236260377859451955180752885715923847446106509971875543496023
2494854876774756172488117802642800540206851318332940739395445903
6305051887120804168979339693187702655904071331731936748927759927
3688881301614948043182289382736687065840703041231428800720854767
0713406956719647313048146023960093662879015837313428567467555885
3564982943420444850950866922223974844727296000000000000000000000
000000000000000000000000000000000000000000000000

CodePudding user response:

Here's mine:

from math import comb, factorial

def get_permutations_count_improved(n_symbols, length, distinct):
    if distinct > length or distinct > n_symbols:
        return 0
    ways = [1]
    for _ in range(length):
        ways = [used * (distinct - d) + new
               for d, used, new in zip(range(distinct+1), [*ways, 0], [0, *ways])]
    return ways[distinct] * comb(n_symbols, distinct) * factorial(distinct)

Speed comparison for some argument sets:

n_symbols length distinct   yours    mine
   300      300    270      0.62 s   0.012 s (~51 times faster)
   512      512    300        -      0.035 s
  1024     1024    600        -      0.22 s
  3000     3000   2700        -      6.0 s

In my last line you see I split the overall result into three factors:

  • comb(n_symbols, distinct) for choosing which distinct out of the n_symbols symbols actually get used. That essentially gets rid of the n_symbols parameter, or you can think of it as compensating for setting n_symbols = distinct.
  • factorial(distinct) for the order in which the symbols get used first. This gets rid of the * (n_symbols - used) in your recurrence.
  • ways[distinct] is the number of ways to build a sequence of length length with exactly distinct distinct symbols, where the order in which they get used first is fixed.
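The three-factor split can be checked by brute force for small arguments. The following is only a sketch: the helper names brute_force_count and fixed_order_count are made up for this illustration and are not part of the answer's code. It enumerates every sequence, so it is usable only for tiny inputs.

```python
from itertools import product
from math import comb, factorial

def brute_force_count(n_symbols, length, distinct):
    """Count sequences with exactly `distinct` different symbols by enumeration."""
    return sum(
        len(set(seq)) == distinct
        for seq in product(range(n_symbols), repeat=length)
    )

def fixed_order_count(length, distinct):
    """Count sequences over symbols 0..distinct-1 that use all of them,
    with first occurrences appearing in the fixed order 0, 1, 2, ..."""
    count = 0
    for seq in product(range(distinct), repeat=length):
        first_seen = list(dict.fromkeys(seq))  # symbols in order of first use
        if first_seen == list(range(distinct)):
            count += 1
    return count

# Small example: alphabet of 4 symbols, sequences of length 5, exactly 3 distinct.
n_symbols, length, distinct = 4, 5, 3
total = brute_force_count(n_symbols, length, distinct)
factored = (comb(n_symbols, distinct)            # which symbols get used
            * factorial(distinct)                # order of first use
            * fixed_order_count(length, distinct))  # sequences for one fixed order
assert total == factored
print(total)  # 600
```

Here fixed_order_count plays the role that ways[distinct] plays in the code above, computed the slow way.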

It might be easier to think of the ways table as two-dimensional: ways[length][distinct]. For better memory efficiency, though, I compute it row by row and keep only the latest row.
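For illustration, here is a sketch of what the two-dimensional variant might look like. The function name count_with_full_table and the extra guard for negative distinct are assumptions added for this example. It uses the same update rule as the one-row code, so row l of the table should hold what the one-row version has after l iterations, padded with zeros.

```python
from math import comb, factorial

def count_with_full_table(n_symbols, length, distinct):
    # Guard for negative `distinct` is an addition for this sketch.
    if distinct < 0 or distinct > length or distinct > n_symbols:
        return 0
    # table[l][d] uses the same indexing as the one-row list after l steps.
    table = [[0] * (distinct + 1) for _ in range(length + 1)]
    table[0][0] = 1
    for l in range(1, length + 1):
        for d in range(distinct + 1):
            # Same update as the one-row code: used * (distinct - d) + new
            used = table[l - 1][d]
            new = table[l - 1][d - 1] if d > 0 else 0
            table[l][d] = used * (distinct - d) + new
    return table[length][distinct] * comb(n_symbols, distinct) * factorial(distinct)

print(count_with_full_table(3, 3, 2))  # 18
```

The full table needs (length + 1) * (distinct + 1) big integers, which is the memory cost the row-by-row version avoids.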

Benchmark and some correctness checks:

from timeit import timeit
from functools import lru_cache
from math import comb, factorial

@lru_cache
def get_permutations_count(n_symbols, length, distinct, used=0):
    '''
     - n_symbols: number of symbols in the alphabet
     - length: the number of symbols in each sequence
     - distinct: the number of distinct symbols in each sequence
    '''
    if distinct < 0:
        return 0
    if length == 0:
        return 1 if distinct == 0 else 0
    else:
        return \
          get_permutations_count(n_symbols, length-1, distinct-0, used+0) * used + \
          get_permutations_count(n_symbols, length-1, distinct-1, used+1) * (n_symbols - used)

def get_permutations_count_improved(n_symbols, length, distinct):
    if distinct > length or distinct > n_symbols:
        return 0
    ways = [1]
    for _ in range(length):
        ways = [used * (distinct - d) + new
               for d, used, new in zip(range(distinct+1), [*ways, 0], [0, *ways])]
    return ways[distinct] * comb(n_symbols, distinct) * factorial(distinct)

funcs = get_permutations_count, get_permutations_count_improved

# Check correctness
stop = 20
for a in range(stop):
    for b in range(stop):
        for c in range(stop):
            expect = get_permutations_count(a, b, c)
            result = get_permutations_count_improved(a, b, c)
            assert result == expect, (a, b, c, expect, result)

# Benchmark
n_symbols, length, distinct = 300, 300, 270
#n_symbols, length, distinct = 512, 512, 300
#n_symbols, length, distinct = 1024, 1024, 600
#n_symbols, length, distinct = 3000, 3000, 2700
for func in funcs[0:] * 3:
    funcs[0].cache_clear()
    t = timeit(lambda: func(n_symbols, length, distinct), number=1)
    print('%.3f seconds ' % t, func.__name__)