I understand how to handle the three base cases (when n is 0, 1 or 2), but how can I make my function return a combination of three recursive calls?
I'm after a recursive function thrib(n) that, given n, returns the n-th term defined by:
```
thrib(0) = 1
thrib(1) = 1
thrib(2) = 1
thrib(n) = thrib(n-1) + 2*thrib(n-2) + 3*thrib(n-3)    if n > 2
```
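It would look similar to the Fibonacci code below, though of course the details differ:
```python
def fib(n):
    if n <= 1:
        return n
    return fib(n-1) + fib(n-2)

print(fib(2))
print(fib(10))
print(fib(30))
# fib(40) and fib(50) are very slow with this naive recursive version
print(fib(40))
print(fib(50))
```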
The purpose of this function is to later write versions that use memoisation and iteration.
Answer:
```python
def trib(n):
    if n < 3:
        return 1
    else:
        return trib(n - 1) + trib(n - 2) + trib(n - 3)

for x in range(10):
    print("trib(%d) == %d" % (x, trib(x)))
```
```
trib(0) == 1
trib(1) == 1
trib(2) == 1
trib(3) == 3
trib(4) == 5
trib(5) == 9
trib(6) == 17
trib(7) == 31
trib(8) == 57
trib(9) == 105
```
```
real 0m0.034s <-
user 0m0.025s
sys 0m0.008s
```
For x greater than 20, things slow down considerably, because every call with n > 2 makes three more calls and the same subproblems are recomputed over and over:
```python
for x in range(20, 30):
    print("trib(%d) == %d" % (x, trib(x)))
```
```
trib(20) == 85525
trib(21) == 157305
trib(22) == 289329
trib(23) == 532159
trib(24) == 978793
trib(25) == 1800281
trib(26) == 3311233
trib(27) == 6090307
trib(28) == 11201821
trib(29) == 20603361
```
```
real 0m30.229s <-
user 0m13.072s
sys 0m0.065s
```
Adding functools.lru_cache to the function gives you memoisation for free: each trib(n) is computed once and afterwards looked up in the cache:
```python
from functools import lru_cache # <-

@lru_cache # <-  (bare decorator form needs Python 3.8+; default maxsize is 128)
def trib(n):
    if n < 3:
        return 1
    else:
        return trib(n - 1) + trib(n - 2) + trib(n - 3)

for x in range(20, 30):
    print("trib(%d) == %d" % (x, trib(x)))
```
```
trib(20) == 85525
trib(21) == 157305
trib(22) == 289329
trib(23) == 532159
trib(24) == 978793
trib(25) == 1800281
trib(26) == 3311233
trib(27) == 6090307
trib(28) == 11201821
trib(29) == 20603361
```
```
real 0m0.041s
user 0m0.033s
sys 0m0.008s
```
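A quick way to confirm that the cache is doing the work is `cache_info()`, which `lru_cache` attaches to the decorated function. This sketch uses `maxsize=None` (an unbounded cache) instead of the default of 128; for this recursion either setting works, since each call only needs the three most recent results.

```python
from functools import lru_cache

@lru_cache(maxsize=None)  # unbounded cache: entries are never evicted
def trib(n):
    if n < 3:
        return 1
    return trib(n - 1) + trib(n - 2) + trib(n - 3)

print(trib(29))           # 20603361
print(trib.cache_info())  # CacheInfo(hits=52, misses=30, maxsize=None, currsize=30)
```

Every n from 0 to 29 is computed exactly once (30 misses); the remaining 52 lookups are the trib(n - 2) and trib(n - 3) calls, which find their values already cached.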
Now we can compute huge values, for x in range(500, 510), in a fraction of a second. Python integers have arbitrary precision, so these results do not overflow:
```
trib(500) == 920080768385554537118362247382795511748094060211249395714846663504876568933187764773247307152478284645110957568578400676655207195125
trib(501) == 1692292371018818732848588652741948525682146322571135018598482022598345259010668154715792046832331419615682051662747684749969421243929
trib(502) == 3112610943964882401860542627651146026612901424988062320853721075554798348058448628208346771618611840323992720078687746154459656020977
trib(503) == 5724984083369255671827493527775890064043141807770446735167049761658020176002304547697386125603421544584785729310013831581084284460031
trib(504) == 10529887398352956806536624808168984616338189555329644074619252859811163783071421330621524944054364804524460501051449262485513361724937
trib(505) == 19367482425687094880224660963596020706994232788088153130640023697023982307132174506527257841276398189433238950440150840221057302205945
trib(506) == 35622353907409307358588779299540895387375564151188243940426326318493166266205900384846168910934184538542485180801613934287654948390913
trib(507) == 65519723731449359045350065071305900710707986494606041145685602875328312356409496221994951696264947532500184632293214036994225612321795
trib(508) == 120509560064545761284163505334442816805077783433882438216751952890845460929747571113368378448475530260475908763534978811502937862918653
trib(509) == 221651637703404427688102349705289612903161334079676723302863882084666939552362967720209499055674662331518578576629806782784818423631361
```
```
real 0m0.066s # <-
user 0m0.032s
sys 0m0.008s
```
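Memoisation removes the repeated work but not the recursion depth: a cold call to trib(n) still nests about n frames deep, and CPython's default recursion limit is usually 1000. One workaround, sketched below under that assumption, is to fill the cache bottom-up so that every call finds its three predecessors already cached and returns almost immediately.

```python
import sys
from functools import lru_cache

@lru_cache(maxsize=None)
def trib(n):
    if n < 3:
        return 1
    return trib(n - 1) + trib(n - 2) + trib(n - 3)

print(sys.getrecursionlimit())  # typically 1000 in CPython

# A cold call such as trib(5000) would nest thousands of frames and is likely
# to raise RecursionError. Warming the cache in increasing order keeps each
# call shallow, because trib(i - 1), trib(i - 2) and trib(i - 3) are cached.
for i in range(5001):
    trib(i)

print(len(str(trib(5000))))  # number of decimal digits in trib(5000)
```

Raising the limit with `sys.setrecursionlimit` is the other common option, but it only moves the problem and can crash the interpreter if the C stack runs out first.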
Without lru_cache

If you want to understand how lru_cache works, you can model it yourself with a dict:
```python
mem = dict()

def trib(n):
    # memo read: reuse a result computed earlier
    if n in mem:
        return mem[n]
    # compute
    if n < 3:
        r = 1
    else:
        r = trib(n - 1) + trib(n - 2) + trib(n - 3)
    # memo write: remember the result for next time
    mem[n] = r
    return r
```
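Taking the dict model one step further, it can be packaged as a decorator and reused the way `lru_cache` is. `memoize` is a hypothetical helper written for this illustration, not part of `functools`; unlike `lru_cache` it has no size limit, keeps no statistics and only supports hashable positional arguments.

```python
from functools import wraps

def memoize(func):
    """Hypothetical minimal stand-in for lru_cache(maxsize=None)."""
    cache = {}  # one dict per decorated function, keyed by the argument tuple

    @wraps(func)
    def wrapper(*args):
        if args in cache:        # memo read
            return cache[args]
        result = func(*args)     # compute (recursive calls go through wrapper)
        cache[args] = result     # memo write
        return result

    wrapper.cache = cache  # expose the dict so it can be inspected
    return wrapper

@memoize
def trib(n):
    if n < 3:
        return 1
    return trib(n - 1) + trib(n - 2) + trib(n - 3)

print(trib(29))         # 20603361
print(len(trib.cache))  # 30, one entry for each n from 0 to 29
```

The recursion works because the name `trib` inside the function body refers to the decorated wrapper, so every recursive call goes through the cache.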
Answer:
You can implement it with multiple recursive calls:
```python
def trib(n): return 1 if n < 3 else trib(n-1) + trib(n-2) + trib(n-3)
```
or with a single recursive call:
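```python
# a, b, c hold three consecutive terms; each call slides the window one step.
# Python does not eliminate tail calls, so this still recurses n levels deep.
def trib(n, a=1, b=1, c=1): return trib(n-1, b, c, a+b+c) if n else a
```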
or iteratively:
```python
def trib(n):
    a = b = c = 1  # trib(0), trib(1), trib(2)
    for _ in range(n):
        a, b, c = b, c, a + b + c  # slide the window one term forward
    return a
```
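Both answers compute a plain sum of the previous three terms. If the recurrence in the question is meant as thrib(n) = thrib(n-1) + 2*thrib(n-2) + 3*thrib(n-3) (the operators were lost from the question text, so this reading is an assumption), only the combining step changes. The sketch below shows a memoised recursive version and an iterative one; the names `thrib_rec` and `thrib_iter` are illustrative.

```python
from functools import lru_cache

# Assumed recurrence: thrib(n) = thrib(n-1) + 2*thrib(n-2) + 3*thrib(n-3), base cases 1.

@lru_cache(maxsize=None)
def thrib_rec(n):
    if n < 3:
        return 1
    return thrib_rec(n - 1) + 2 * thrib_rec(n - 2) + 3 * thrib_rec(n - 3)

def thrib_iter(n):
    a = b = c = 1  # thrib(k), thrib(k+1), thrib(k+2) for the current k
    for _ in range(n):
        # the new term weights the newest value by 1, the middle by 2, the oldest by 3
        a, b, c = b, c, c + 2 * b + 3 * a
    return a

print([thrib_iter(n) for n in range(8)])  # [1, 1, 1, 6, 11, 26, 66, 151]
```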