Recursive "decorator" in Kotlin

Time:03-05

Suppose I have a recursive function like fibonacci:

import java.math.BigInteger

fun fibonacci(n: Int): BigInteger =
    if (n < 2)
        n.toBigInteger()
    else
        fibonacci(n - 1) + fibonacci(n - 2)

This is slow because it recalculates the same values over and over (the number of calls grows exponentially with n). I can fix this by adding a "memo":

import java.math.BigInteger
import java.util.concurrent.ConcurrentSkipListMap

val memo = ConcurrentSkipListMap<Int, BigInteger>()

fun mFibonacci(n: Int): BigInteger =
    memo.computeIfAbsent(n) {
        if (n < 2)
            it.toBigInteger()
        else
            mFibonacci(n - 1) + mFibonacci(n - 2)
    }

Works like a charm, but can I do this without touching the function? My first thought was to use a wrapper class:

class Cached<in T, out R>(private val f: (T) -> R) : (T) -> R {
    private val cache = ConcurrentSkipListMap<T, R>()
    override fun invoke(x: T): R = cache.computeIfAbsent(x, f)
}

val cFibonacci = Cached(::fibonacci)

... but the problem is that this only memoizes the outermost call: the recursive calls inside fibonacci call fibonacci directly and never go through the cache. If I call cFibonacci with a "big" number like 42, it takes a long time and then puts the correct value in the cache; subsequent calls with 42 will be fast, but 41 will be slow again. Compare this to mFibonacci, which runs fast the first time and populates the memo with every value from 0 up to 42.
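The difference can be checked with a small sketch. It assumes a hypothetical top-level `calls` counter added to the plain `fibonacci`, and for simplicity backs `Cached` with a `HashMap` and `getOrPut` instead of `ConcurrentSkipListMap`; neither change is part of the original code.

```kotlin
import java.math.BigInteger
import java.util.concurrent.ConcurrentSkipListMap

// Hypothetical counter, only here to show how often the plain function body runs.
var calls = 0

// Plain recursive version: the inner calls name `fibonacci` directly.
fun fibonacci(n: Int): BigInteger {
    calls++
    return if (n < 2) n.toBigInteger() else fibonacci(n - 1) + fibonacci(n - 2)
}

// Wrapper in the spirit of `Cached`, backed by a HashMap for simplicity.
class Cached<T, R>(private val f: (T) -> R) : (T) -> R {
    private val cache = HashMap<T, R>()
    // Only the argument passed to invoke is looked up and stored.
    override fun invoke(x: T): R = cache.getOrPut(x) { f(x) }
}

// Memoized version from the question: every recursive call goes through `memo`.
val memo = ConcurrentSkipListMap<Int, BigInteger>()

fun mFibonacci(n: Int): BigInteger =
    memo.computeIfAbsent(n) {
        if (n < 2) n.toBigInteger() else mFibonacci(n - 1) + mFibonacci(n - 2)
    }
```

With `val cFibonacci = Cached(::fibonacci)`, the first `cFibonacci(20)` runs the body of `fibonacci` 21,891 times, a repeated `cFibonacci(20)` runs it zero times, and `cFibonacci(19)` runs it 13,529 times again. After a single `mFibonacci(42)`, by contrast, `memo` holds every key from 0 to 42.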

In Python, I can write a "decorator" which does this.

def memoized(f):
    def helper(t):
        if t in helper.memo:
            return helper.memo[t]
        else:
            r = f(t)
            helper.memo[t] = r
            return r
    helper.memo = {}
    return helper

@memoized
def fib(n):
    if n < 2:
        return n
    else:
        return fib(n - 1) + fib(n - 2)

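This works just like mFibonacci above. I can also apply it after the fact, as fib = memoized(fib), if I don't have access to the definition (for a function defined in another module, the rebinding has to happen on that module's attribute, since that is the name the recursive calls look up). Interestingly, c_fib = memoized(fib) works like Cached/cFibonacci above, hinting that being able to rebind the name used by the recursive calls may be what makes this work.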
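The hint about rebinding can be tried in Kotlin as well. The sketch below is a rough analogue of the Python decorator, not an established idiom, and it rests on an assumption: the recursive function is stored in a top-level `var` of function type (hypothetical name `fib`) instead of being declared with `fun`. `memoized` is likewise a hypothetical helper. Because the lambda's recursive calls read `fib` each time they run, reassigning `fib` redirects them through the memo.

```kotlin
import java.math.BigInteger

// Hypothetical helper playing the role of the Python `memoized` decorator.
fun <T, R> memoized(f: (T) -> R): (T) -> R {
    val memo = HashMap<T, R>()
    // getOrPut returns the cached value, or calls f and stores the result on a miss.
    return { t -> memo.getOrPut(t) { f(t) } }
}

// Assumption: the recursion goes through the mutable top-level property `fib`,
// so the inner calls use whatever `fib` refers to at the time they run.
var fib: (Int) -> BigInteger = { n ->
    if (n < 2) n.toBigInteger() else fib(n - 1) + fib(n - 2)
}
```

Running `fib = memoized(fib)` once makes a call like `fib(100)` fast, while `val cFib = memoized(fib)` without the reassignment only caches the outermost call, like `Cached`. The price is mutable global state; a function declared with `fun` cannot be rebound this way.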

The question is: (how) can I wrap/"decorate" a recursive function in a way that affects the inner calls in Kotlin the way I can in Python?

Answer:

I'll suggest a workaround in the absence of a solution. This pattern requires access to the definition of the function (i.e. it can't be an import):

import java.math.BigInteger
import java.util.concurrent.ConcurrentSkipListMap

object fibonacci : (Int) -> BigInteger {
    private val memo = ConcurrentSkipListMap<Int, BigInteger>()
    override fun invoke(n: Int): BigInteger = fibonacci(n)
    private fun fibonacci(n: Int): BigInteger = memo.computeIfAbsent(n) {
        if (n < 2)
            n.toBigInteger()
        else
            fibonacci(n - 1) + fibonacci(n - 2)
    }
}

There are a few decisions here that may need justification:

  1. I'm using a camelCase name instead of PascalCase. Although an object is a class, it's called like a function, so I feel a function's naming convention fits better. With this, you can call the fibonacci function exactly as you normally would.

  2. I've put the body in a private function named fibonacci instead of in invoke, and invoke just delegates to it. Without this, the recursive calls would use invoke, which seems less readable to me. With this, you can read and write the fibonacci function (almost*) exactly as you normally would.

In general, the idea is to be as unintrusive as possible while still adding the desired functionality. I'm open to suggestions on how to improve it, though!

*Note that the body is a lambda passed to computeIfAbsent, so a plain return isn't allowed there: the lambda's last expression is its result. If you have a single return at the end of the function, just remove the return keyword. If you have multiple returns, you'll have to use the less-than-beautiful return@computeIfAbsent for the early exits.
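For instance, the same object could be written with two early exits. This sketch splits the `n < 2` branch into separate cases purely to show the labeled return; it is not a change the answer recommends.

```kotlin
import java.math.BigInteger
import java.util.concurrent.ConcurrentSkipListMap

object fibonacci : (Int) -> BigInteger {
    private val memo = ConcurrentSkipListMap<Int, BigInteger>()
    override fun invoke(n: Int): BigInteger = fibonacci(n)
    private fun fibonacci(n: Int): BigInteger = memo.computeIfAbsent(n) {
        // A bare `return` is not allowed inside this lambda, so early exits use
        // the lambda's implicit label, named after the function it is passed to.
        if (n == 0) return@computeIfAbsent BigInteger.ZERO
        if (n == 1) return@computeIfAbsent BigInteger.ONE
        // The last expression is the lambda's result, so no `return` is needed here.
        fibonacci(n - 1) + fibonacci(n - 2)
    }
}
```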

Answer:

As requested, this is a variation of @AlexJones's workaround that doesn't wrap the function in an object with unconventional naming. I did not test this; it's based on the assumption that the other solution works. The following code would go at the top level of a .kt file.

import java.math.BigInteger
import java.util.concurrent.ConcurrentSkipListMap

private val memo = ConcurrentSkipListMap<Int, BigInteger>()

fun fibonacci(n: Int): BigInteger = fibonacciImpl(n)

private fun fibonacciImpl(n: Int): BigInteger = memo.computeIfAbsent(n) {
    if (n < 2)
        n.toBigInteger()
    else
        fibonacci(n - 1) + fibonacci(n - 2)
}
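A further workaround, not taken from either answer, is open recursion: the body receives the function to use for recursive calls as a parameter, so one generic wrapper can route every inner call through its cache. It still requires editing the body (recursive calls go through `self`), so it shares the limitation of the answers above. `MemoizedRecursive` is a hypothetical name, and this sketch is not thread-safe because it uses a plain `HashMap`.

```kotlin
import java.math.BigInteger

// Hypothetical generic wrapper: `body` receives this memoized function as `self`,
// so every recursive call made through `self` is looked up in (and stored into) `memo`.
class MemoizedRecursive<T, R>(private val body: (self: (T) -> R, x: T) -> R) : (T) -> R {
    private val memo = HashMap<T, R>()
    override fun invoke(x: T): R = memo.getOrPut(x) { body(this, x) }
}

// The recursive definition calls `self` instead of naming itself.
val fibonacci = MemoizedRecursive<Int, BigInteger> { self, n ->
    if (n < 2) n.toBigInteger() else self(n - 1) + self(n - 2)
}
```

Compared with the object-based answer, the caching logic lives in one reusable class, at the cost of the extra `self` parameter in every recursive definition.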