Suppose I have a recursive function like fibonacci:
import java.math.BigInteger

fun fibonacci(n: Int): BigInteger =
    if (n < 2)
        n.toBigInteger()
    else
        fibonacci(n - 1) + fibonacci(n - 2)
This is slow because I'm recalculating known values over and over; the number of calls grows exponentially with n. I can fix this by adding a "memo":
import java.util.concurrent.ConcurrentSkipListMap

val memo = ConcurrentSkipListMap<Int, BigInteger>()
fun mFibonacci(n: Int): BigInteger =
    memo.computeIfAbsent(n) {
        if (n < 2)
            it.toBigInteger()
        else
            mFibonacci(n - 1) + mFibonacci(n - 2)
    }
Works like a charm, but can I do this without touching the function? My first thought was to use a wrapper class:
class Cached<in T, out R>(private val f: (T) -> R) : (T) -> R {
    private val cache = ConcurrentSkipListMap<T, R>()
    override fun invoke(x: T): R = cache.computeIfAbsent(x, f)
}

val cFibonacci = Cached(::fibonacci)
... but the problem is, this only memoizes the outermost call. If I call cFibonacci with a "big" number like 42, it takes a long time and then puts the correct value in the memo; subsequent calls with 42 will be fast, but 41 will be slow again. Compare this to mFibonacci, which runs fast the first time and populates the memo with values from 0 up to 42.
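To make the difference concrete, here is a small instrumented sketch. The calls counter and the countCalls helper are hypothetical additions, not part of the question; they count how often the plain, un-memoized body of fibonacci actually runs when it is called through Cached.

```kotlin
import java.math.BigInteger
import java.util.concurrent.ConcurrentSkipListMap

// Hypothetical instrumentation: counts how often the plain body runs.
var calls = 0

fun fibonacci(n: Int): BigInteger {
    calls++
    // The recursive calls go straight to fibonacci, bypassing any wrapper.
    return if (n < 2) n.toBigInteger() else fibonacci(n - 1) + fibonacci(n - 2)
}

// Same wrapper as in the question; `{ f(it) }` just passes f explicitly.
class Cached<in T, out R>(private val f: (T) -> R) : (T) -> R {
    private val cache = ConcurrentSkipListMap<T, R>()
    override fun invoke(x: T): R = cache.computeIfAbsent(x) { f(it) }
}

// Hypothetical helper: reset the counter, make one call, report the count.
fun countCalls(f: (Int) -> BigInteger, n: Int): Int {
    calls = 0
    f(n)
    return calls
}

fun main() {
    val cFibonacci = Cached(::fibonacci)
    println(countCalls(cFibonacci, 20)) // 21891: every inner call recomputes
    println(countCalls(cFibonacci, 20)) // 0: only the outer result was cached
    println(countCalls(cFibonacci, 19)) // 13529: 19 was never stored
}
```

The counts follow from the naive recursion making 2*F(n+1) - 1 calls for input n (F(21) = 10946, F(20) = 6765).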
In Python, I can write a "decorator" which does this.
def memoized(f):
    def helper(t):
        if t in helper.memo:
            return helper.memo[t]
        else:
            r = f(t)
            helper.memo[t] = r
            return r
    helper.memo = {}
    return helper

@memoized
def fib(n):
    if n < 2:
        return n
    else:
        return fib(n-1) + fib(n-2)
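This works just like mFibonacci above. I can also write fib = memoized(fib) if I imported fib from somewhere else and don't have access to the definition (although for the inner calls to be memoized, the name has to be rebound in the module that defines fib, e.g. some_module.fib = memoized(some_module.fib), because that is where the recursive calls look it up). Interestingly, c_fib = memoized(fib) works like Cached/cFibonacci above, hinting that the trick depends on being able to rebind the name fib, which the recursive calls resolve at call time.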
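For comparison, the Python behaviour can be imitated in Kotlin when the recursive calls go through a mutable property instead of a named function. This is a sketch under assumptions: memoize is a hypothetical helper and fib is a top-level var holding a lambda; neither comes from the question.

```kotlin
import java.math.BigInteger
import java.util.concurrent.ConcurrentSkipListMap

// Hypothetical generic helper, the Kotlin counterpart of Python's `memoized`.
// ConcurrentSkipListMap.computeIfAbsent looks the key up, runs f, then inserts,
// so it copes with f inserting smaller keys along the way (as mFibonacci relies on).
fun <T : Comparable<T>, R> memoize(f: (T) -> R): (T) -> R {
    val cache = ConcurrentSkipListMap<T, R>()
    return { x -> cache.computeIfAbsent(x) { f(it) } }
}

// The recursive calls read the property `fib` each time they run, so they
// pick up whatever function is stored in it at that moment.
var fib: (Int) -> BigInteger = { n ->
    if (n < 2) n.toBigInteger() else fib(n - 1) + fib(n - 2)
}

fun main() {
    fib = memoize(fib) // like `fib = memoized(fib)` in Python
    println(fib(90))   // fast: the inner calls hit the cache too
}
```

This still isn't decorating an ordinary fun: it only works because the function was written as a lambda stored in a var from the start, which plays the role of Python's rebindable name.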
The question is: (how) can I wrap/"decorate" a recursive function in a way that affects the inner calls in Kotlin the way I can in Python?
Answer:
I'll suggest a workaround in the absence of a solution. This pattern requires access to the definition of the function (i.e. it can't be an import):
import java.math.BigInteger
import java.util.concurrent.ConcurrentSkipListMap

object fibonacci : (Int) -> BigInteger {
    private val memo = ConcurrentSkipListMap<Int, BigInteger>()
    override fun invoke(n: Int): BigInteger = fibonacci(n)
    private fun fibonacci(n: Int): BigInteger = memo.computeIfAbsent(n) {
        if (n < 2)
            n.toBigInteger()
        else
            fibonacci(n - 1) + fibonacci(n - 2)
    }
}
There are a few decisions here that may need justification:
- I'm using a camelCase name instead of PascalCase. Even though an object declaration defines a class, this one is called like a function, so I feel a function's naming convention fits better. With this, you can call the fibonacci function exactly as you normally would.
- I've effectively renamed invoke to fibonacci: invoke only delegates to a private fibonacci function that holds the implementation. Without this, the recursive calls would use invoke, which seems less readable to me. With this, you can read and write the fibonacci function (almost*) exactly as you normally would.
In general, the idea is to be as unintrusive as possible while still adding the desired functionality. I'm open to suggestions on how to improve it, though!
*Note that the function body is a lambda passed to computeIfAbsent, so a bare return isn't allowed there. If you have a single return at the end of the function, just drop the return keyword; the lambda's last expression becomes its value. If you have multiple returns, you'll have to use the less-than-beautiful return@computeIfAbsent for the short-circuits.
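A sketch of the short-circuit case, using the same object pattern with the base cases rewritten as early exits. The early-return version of fibonacci is only an illustration, not code from the answer.

```kotlin
import java.math.BigInteger
import java.util.concurrent.ConcurrentSkipListMap

object fibonacci : (Int) -> BigInteger {
    private val memo = ConcurrentSkipListMap<Int, BigInteger>()

    override fun invoke(n: Int): BigInteger = fibonacci(n)

    private fun fibonacci(n: Int): BigInteger = memo.computeIfAbsent(n) {
        // Early exits: a bare `return` is not allowed inside this lambda,
        // so each short-circuit is labeled with the function receiving it.
        if (n == 0) return@computeIfAbsent BigInteger.ZERO
        if (n == 1) return@computeIfAbsent BigInteger.ONE
        // The last expression is the lambda's result, so no `return` here.
        fibonacci(n - 1) + fibonacci(n - 2)
    }
}

fun main() {
    println(fibonacci(90)) // called like an ordinary function via invoke
}
```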
Answer:
As requested, this is a variation of @AlexJones's workaround that doesn't wrap the function in an object with unconventional naming. I did not test this; it's based on the assumption that the other solution works. The following code would go at the top level of a .kt file.
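import java.math.BigInteger
import java.util.concurrent.ConcurrentSkipListMap

private val memo = ConcurrentSkipListMap<Int, BigInteger>()
fun fibonacci(n: Int): BigInteger = fibonacciImpl(n)
private fun fibonacciImpl(n: Int): BigInteger = memo.computeIfAbsent(n) {
    // Recursive calls go through the public fibonacci, which delegates back here.
    if (n < 2)
        n.toBigInteger()
    else
        fibonacci(n - 1) + fibonacci(n - 2)
}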