Recursive thribonacci function in python

Time:12-12

I understand how to cover three terms (when n is 0, 1 or 2) as base cases, but how can I have my function return three recursive calls?

I'm after a recursive thrib(n) function that, on input n, returns thrib(n), where:

thrib(0) = 1
thrib(1) = 1
thrib(2) = 1
# if n > 2:
thrib(n) = thrib(n-1) + 2 thrib(n-2) + 3 thrib(n-3)

Something similar to the code below, but different of course:

def fib(n):
    if n <=1:
        return n
    return fib(n-1) + fib(n-2)


print(fib(2))
print(fib(10))
print(fib(30))
print(fib(40))
print(fib(50))

The purpose of this function is to later create functions that use memoisation and iteration.

Answer:

def trib(n):
  if n < 3:
    return 1
  else:
    return trib(n - 1) + trib(n - 2) + trib(n - 3)
for x in range(10):
  print("trib(%d) == %d" % (x, trib(x)))
trib(0) == 1
trib(1) == 1
trib(2) == 1
trib(3) == 3
trib(4) == 5
trib(5) == 9
trib(6) == 17
trib(7) == 31
trib(8) == 57
trib(9) == 105
real    0m0.034s     <-
user    0m0.025s
sys 0m0.008s

For x greater than 20, things start to slow down considerably -

for x in range(20,30):
  print("trib(%d) == %d" % (x, trib(x)))
trib(20) == 85525
trib(21) == 157305
trib(22) == 289329
trib(23) == 532159
trib(24) == 978793
trib(25) == 1800281
trib(26) == 3311233
trib(27) == 6090307
trib(28) == 11201821
trib(29) == 20603361
real    0m30.229s            <-
user    0m13.072s
sys 0m0.065s
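The slowdown comes from recomputation: every call with n > 2 spawns three more, and the same small arguments get recomputed over and over. One rough way to see this is to count the calls. The sketch below bolts a module-level counter onto the naive version; `calls` and `count_calls` are hypothetical helper names, not part of either answer.

```python
# Naive recursive trib with a call counter added for illustration.
# `calls` and `count_calls` are hypothetical helpers, not from the answer.

calls = {"count": 0}


def trib(n):
    calls["count"] += 1  # count every invocation, including repeated ones
    if n < 3:
        return 1
    return trib(n - 1) + trib(n - 2) + trib(n - 3)


def count_calls(n):
    """Return how many times trib is invoked while computing trib(n)."""
    calls["count"] = 0
    trib(n)
    return calls["count"]


for n in (5, 10, 15, 20):
    print(n, count_calls(n))
```

The count obeys calls(n) = 1 + calls(n-1) + calls(n-2) + calls(n-3) with calls(0..2) = 1, which works out to (3*trib(n) - 1)/2. It therefore grows as fast as the values themselves: for n = 29 that is about 31 million calls, which goes a long way toward explaining the timing above.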

Adding functools.lru_cache to the function does exactly what you want -

from functools import lru_cache   # <-

@lru_cache                        # <-
def trib(n):
  if n < 3:
    return 1
  else:
    return trib(n - 1) + trib(n - 2) + trib(n - 3)
trib(20) == 85525
trib(21) == 157305
trib(22) == 289329
trib(23) == 532159
trib(24) == 978793
trib(25) == 1800281
trib(26) == 3311233
trib(27) == 6090307
trib(28) == 11201821
trib(29) == 20603361
real    0m0.041s
user    0m0.033s
sys 0m0.008s
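Functions wrapped by lru_cache also get `cache_info()` and `cache_clear()` methods, which make the effect of the cache visible. A minimal sketch, assuming Python 3.8+ (required for the bare `@lru_cache` form used above):

```python
from functools import lru_cache


@lru_cache  # bare form needs Python 3.8+; equivalent to @lru_cache(maxsize=128)
def trib(n):
    if n < 3:
        return 1
    return trib(n - 1) + trib(n - 2) + trib(n - 3)


print(trib(29))  # 20603361
info = trib.cache_info()  # named tuple: hits, misses, maxsize, currsize
print(info)

trib.cache_clear()  # empty the cache, e.g. before timing the function again
```

On a cold cache, `trib(29)` should report 30 misses, one per distinct n from 0 to 29; every other lookup is a hit, so the function body runs only 30 times instead of millions.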

Now we can compute huge values, like trib(500), in a fraction of a second -

trib(500) == 920080768385554537118362247382795511748094060211249395714846663504876568933187764773247307152478284645110957568578400676655207195125
trib(501) == 1692292371018818732848588652741948525682146322571135018598482022598345259010668154715792046832331419615682051662747684749969421243929
trib(502) == 3112610943964882401860542627651146026612901424988062320853721075554798348058448628208346771618611840323992720078687746154459656020977
trib(503) == 5724984083369255671827493527775890064043141807770446735167049761658020176002304547697386125603421544584785729310013831581084284460031
trib(504) == 10529887398352956806536624808168984616338189555329644074619252859811163783071421330621524944054364804524460501051449262485513361724937
trib(505) == 19367482425687094880224660963596020706994232788088153130640023697023982307132174506527257841276398189433238950440150840221057302205945
trib(506) == 35622353907409307358588779299540895387375564151188243940426326318493166266205900384846168910934184538542485180801613934287654948390913
trib(507) == 65519723731449359045350065071305900710707986494606041145685602875328312356409496221994951696264947532500184632293214036994225612321795
trib(508) == 120509560064545761284163505334442816805077783433882438216751952890845460929747571113368378448475530260475908763534978811502937862918653
trib(509) == 221651637703404427688102349705289612903161334079676723302863882084666939552362967720209499055674662331518578576629806782784818423631361
real    0m0.066s    # <-
user    0m0.032s
sys 0m0.008s
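The cached version still nests roughly n calls deep the first time it sees a new n, so on a cold cache a much larger n than 500 may exceed Python's recursion limit (`sys.getrecursionlimit()`, 1000 by default) and raise RecursionError. One workaround, sketched here under that assumption, is to fill the cache bottom-up so each new call only needs values that are already cached. `trib_upto` is a hypothetical helper name.

```python
from functools import lru_cache


@lru_cache(maxsize=None)  # keep every value; the default maxsize is 128
def trib(n):
    if n < 3:
        return 1
    return trib(n - 1) + trib(n - 2) + trib(n - 3)


def trib_upto(n):
    """Hypothetical helper: warm the cache from 0 up to n, then return trib(n).

    When trib(k) runs, trib(k-1), trib(k-2) and trib(k-3) are already
    cached, so its recursive calls return immediately instead of nesting.
    """
    for k in range(n + 1):
        trib(k)
    return trib(n)


big = trib_upto(3000)
print(len(str(big)), "digits")
```

The trade-off is memory: with `maxsize=None` every intermediate value stays in the cache.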

Without lru_cache

If you want to understand how lru_cache works, we can model it similarly using a dict -

mem = dict()

def trib(n):
  # memo read
  if n in mem: return mem[n]
  # memo write
  if n < 3:
    r = 1
  else:
    r = trib(n - 1) + trib(n - 2) + trib(n - 3)
  mem[n] = r
  return r
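The dict version can be generalised into a decorator, which is closer in spirit to how `@lru_cache` wraps a function. This is only a sketch: `memoize` is a hypothetical name (not a standard library function), it has no maxsize or eviction, and it assumes positional, hashable arguments.

```python
from functools import wraps


def memoize(fn):
    """Hypothetical decorator: cache fn's results in a dict keyed by args.

    Unlike functools.lru_cache there is no maxsize or eviction, and only
    positional, hashable arguments are supported.
    """
    cache = {}

    @wraps(fn)
    def wrapper(*args):
        if args in cache:  # memo read
            return cache[args]
        result = fn(*args)  # compute on a miss
        cache[args] = result  # memo write
        return result

    wrapper.cache = cache  # expose the dict for inspection
    return wrapper


@memoize
def trib(n):
    if n < 3:
        return 1
    # these calls go through the wrapper, because the name trib now refers to it
    return trib(n - 1) + trib(n - 2) + trib(n - 3)


print(trib(100))
print(len(trib.cache))  # 101: one entry per n from 0 to 100
```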

Answer:

You can implement it with multiple recursive calls:

def trib(n): return 1 if n<3 else trib(n-1)+trib(n-2)+trib(n-3)

or with a single recursive call:

def trib(n,a=1,b=1,c=1): return trib(n-1,b,c,a+b+c) if n else a
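The single-call version works by carrying the last three values forward as arguments: each step shifts the window (a, b, c) to (b, c, a+b+c) and decrements n, so when n reaches 0, a holds trib of the original n. Below is an expanded, commented sketch of the same idea, plus a hypothetical `trace` helper that records each window. It still recurses n levels deep (CPython does not eliminate tail calls), so it is bounded by the recursion limit.

```python
def trib(n, a=1, b=1, c=1):
    """Tribonacci via one tail-style recursive call.

    (a, b, c) is a sliding window over the sequence; each step advances it
    by one position, and after n steps `a` is trib(n).
    """
    if n == 0:
        return a
    return trib(n - 1, b, c, a + b + c)


def trace(n):
    """Hypothetical helper: list the (n, a, b, c) states trib(n) passes through."""
    a, b, c = 1, 1, 1
    steps = [(n, a, b, c)]
    while n:
        n, a, b, c = n - 1, b, c, a + b + c
        steps.append((n, a, b, c))
    return steps


for step in trace(4):
    print(step)
# (4, 1, 1, 1)
# (3, 1, 1, 3)
# (2, 1, 3, 5)
# (1, 3, 5, 9)
# (0, 5, 9, 17)   <- a == 5 == trib(4)
```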

or iteratively:

def trib(n):
    a = b = c = 1
    for _ in range(n):
        a,b,c = b,c,a+b+c
    return a
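Since the question's goal is to compare recursive, memoised and iterative versions, a quick consistency check is useful. This sketch runs the three implementations from this answer under distinct names (`trib_rec`, `trib_acc`, `trib_iter`, chosen here for illustration) and asserts they agree. It also shows a hypothetical generator variant, `trib_seq`, of the iterative loop that yields the sequence lazily.

```python
from itertools import islice


def trib_rec(n):
    # multiple recursive calls (exponential time)
    return 1 if n < 3 else trib_rec(n - 1) + trib_rec(n - 2) + trib_rec(n - 3)


def trib_acc(n, a=1, b=1, c=1):
    # single recursive call carrying the last three values
    return trib_acc(n - 1, b, c, a + b + c) if n else a


def trib_iter(n):
    # iterative sliding window
    a = b = c = 1
    for _ in range(n):
        a, b, c = b, c, a + b + c
    return a


def trib_seq():
    """Hypothetical generator: yield trib(0), trib(1), trib(2), ... forever."""
    a = b = c = 1
    while True:
        yield a
        a, b, c = b, c, a + b + c


# Keep n small so the exponential trib_rec finishes quickly.
for n in range(20):
    assert trib_rec(n) == trib_acc(n) == trib_iter(n)

print(list(islice(trib_seq(), 10)))
# [1, 1, 1, 3, 5, 9, 17, 31, 57, 105]
```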