How to write optimal nested recursive functions

Time:09-30

I'm trying to translate a complicated Excel spreadsheet to Python. Here's a simple example of what the spreadsheet looks like:

r          a                b         c              d
1          2%               a(r)+1    b(r)*2+a(r)    100
2 to 100   d(r-1)*0.0001    a(r)+1    b(r)*2+a(r)    d(r-1)*c(r)

Here's what I wrote:

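def a(r):
    # column a: 2% in row 1, otherwise 0.0001 * the previous row's d
    if r == 1:
        return 0.02
    else:
        return d(r - 1) * 0.0001

def b(r):
    return a(r) + 1

def c(r):
    return b(r) * 2 + a(r)

def d(r):
    # column d: 100 in row 1, otherwise previous d times this row's c
    if r == 1:
        return 100
    else:
        return d(r - 1) * c(r)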

I understand this is a naive example and that I don't need a, b and c: I could combine everything into one recursive function. But the real spreadsheet is more complicated, and combining the columns would take a lot of effort.

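My question is about efficiency. If I write the recursive functions separately, a(r) is called twice per row (once through b(r) and once directly in c(r)), and each of those calls recomputes d(r-1), on top of the direct call to d(r-1) inside d(r). So the work triples with every row and quickly gets out of hand. Is there a way to write the functions so that a is only called once per step?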


Calling d(100) this way takes 3^99 - 1 (roughly 1.7*10^47) calls to a, so it can never finish.
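To see the growth concretely, here is a small instrumented copy of the naive functions that counts how many times a runs for small r. The `calls` counter is a hypothetical addition for illustration only. Because each d(r) evaluates d(r-1) three times (once directly and once inside each of the two a(r) calls), the count T(r) follows T(r) = 3*T(r-1) + 2 with T(1) = 0, which works out to T(r) = 3^(r-1) - 1.

```python
from collections import Counter

calls = Counter()  # hypothetical instrumentation: how many times each function body runs

def a(r):
    calls["a"] += 1
    if r == 1:
        return 0.02
    return d(r - 1) * 0.0001

def b(r):
    return a(r) + 1

def c(r):
    return b(r) * 2 + a(r)

def d(r):
    if r == 1:
        return 100
    return d(r - 1) * c(r)

for r in range(1, 8):
    calls.clear()
    d(r)
    # the last two columns (measured count vs. 3^(r-1) - 1) should agree
    print(r, calls["a"], 3 ** (r - 1) - 1)
```

Each extra row multiplies the work by about three, which is why d(100) is hopeless without caching or a loop.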


Thanks to Tuqay in the comments, adding the following code solved the problem and sped up d(100) to less than a second:

from functools import lru_cache

# cache a(r): the second call for the same row becomes a lookup instead of a recomputation
@lru_cache(maxsize=128)
def a(r):
    if r == 1:
        return 0.02
    else:
        return d(r - 1) * 0.0001

def b(r):
    return a(r) + 1

def c(r):
    return b(r) * 2 + a(r)

def d(r):
    if r == 1:
        return 100
    else:
        return d(r - 1) * c(r) / 3
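Caching a alone already tames the blow-up, because the second a(r) call in each row becomes a cache hit. A slightly more general sketch, assuming you are happy to cache every column: decorate all four functions with `functools.lru_cache(maxsize=None)` so each (function, row) pair is computed exactly once, and use `cache_info()` to confirm it. This sketch uses the formulas from the spreadsheet table, i.e. d(r) = d(r-1)*c(r) without the extra /3 in the edited code.

```python
from functools import lru_cache

# Assumption: cache every column, not just a; maxsize=None means entries are never evicted.
@lru_cache(maxsize=None)
def a(r):
    if r == 1:
        return 0.02
    return d(r - 1) * 0.0001

@lru_cache(maxsize=None)
def b(r):
    return a(r) + 1

@lru_cache(maxsize=None)
def c(r):
    return b(r) * 2 + a(r)

@lru_cache(maxsize=None)
def d(r):
    if r == 1:
        return 100
    return d(r - 1) * c(r)

result = d(100)
# misses = distinct rows computed, hits = repeated requests answered from the cache
print(d.cache_info())  # CacheInfo(hits=99, misses=100, maxsize=None, currsize=100)
```

With the toy constants in this example, d grows so quickly that the float result overflows to inf well before row 100; that comes from the numbers, not from the caching. For 100 rows, `maxsize=128` would also be enough, since each function only caches 100 entries. For many thousands of rows the recursion itself would eventually hit Python's default recursion limit, which is one argument for the loop-based approach below.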

Answer:

Your function is not recursive in the sense of "keep calling myself until some condition is met," as it only refers to values in the current or previous iteration. I'd call it self-referential but not recursive or tail-recursive.

So if you simply keep track of the value of d(r-1) from the previous iteration, and compute a(r) once per iteration, then you should be good to go.
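A minimal sketch of that idea which still keeps each column as its own small function, as the question wants: fill the sheet one row at a time, passing each formula the cells already computed in the current row plus the previous row. The names `col_a` to `col_d`, `COLUMNS` and `build_sheet` are hypothetical, chosen for illustration.

```python
# Each column is its own small function, like a spreadsheet formula.
# It receives the row number, the cells already filled in for this row,
# and the previous row (None when r == 1).
def col_a(r, row, prev):
    return 0.02 if r == 1 else prev["d"] * 0.0001

def col_b(r, row, prev):
    return row["a"] + 1

def col_c(r, row, prev):
    return row["b"] * 2 + row["a"]

def col_d(r, row, prev):
    return 100 if r == 1 else prev["d"] * row["c"]

# Evaluated left to right, so a formula may read any column to its left
# in the same row, or any column of the previous row.
COLUMNS = [("a", col_a), ("b", col_b), ("c", col_c), ("d", col_d)]

def build_sheet(n):
    """Return rows 1..n as a list of dicts; every cell is computed exactly once."""
    rows = []
    prev = None
    for r in range(1, n + 1):
        row = {}
        for name, formula in COLUMNS:
            row[name] = formula(r, row, prev)
        rows.append(row)
        prev = row  # only the previous row is needed for the next one
    return rows

sheet = build_sheet(100)
last = sheet[-1]  # row 100; last["d"] corresponds to d(100)
```

There is no recursion, so the number of rows is not limited by Python's recursion depth, and the work grows linearly with the row count. The left-to-right order matches this sheet's layout; a sheet whose formulas reference cells to the right in the same row would need a different evaluation order.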
