How to implement memoization in Scala without mutability?

Time:03-21

I was recently reading Category Theory for Programmers, and in one of the challenges Bartosz proposes writing a function called memoize. It takes a function as an argument and returns a function that behaves the same way, except that the first time it is called with a given argument it stores the result, and on every later call with that argument it returns the stored result.

def memoize[A, B](f: A => B): A => B = ???

The problem is that I can't think of any way to implement this function without resorting to mutability. Moreover, the implementations I have seen use mutable data structures to accomplish the task.

My question is, is there a purely functional way of accomplishing this? Maybe without mutability or by using some functional trick?

Thanks for reading my question and for any future help. Have a nice day!

CodePudding user response:

is there a purely functional way of accomplishing this?

No. Not in the narrowest sense of pure functions and using the given signature.

TLDR: Use mutable collections, it's okay!

Impurity of g

val g = memoize(f)
// state 1
g(a)
// state 2

What would you expect to happen for the call g(a)?

If g(a) memoizes the result, an (internal) state has to change, so the state is different after the call g(a) than before. As this could be observed from the outside, the call to g has side effects, which makes your program impure.

From the book you referenced, section 2.5 "Pure and Dirty Functions":

[...] functions that

  • always produce the same result given the same input and
  • have no side effects

are called pure functions.

Is this really a side effect?

Normally, at least in Scala, internal state changes are not considered side effects.

See the definition in the Scala Book:

A pure function is a function that depends only on its declared inputs and its internal algorithm to produce its output. It does not read any other values from “the outside world” — the world outside of the function’s scope — and it does not modify any values in the outside world.

The following examples of lazy computations both change their internal state, but are normally still considered purely functional, as they always yield the same result and have no side effects apart from that internal state change:

lazy val x = 1
// state 1: x is not computed
x
// state 2: x is 1
val ll = LazyList.continually(0)
// state 1: ll = LazyList(<not computed>)
ll(0)
// state 2: ll = LazyList(0, <not computed>)

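In your case, the equivalent would be something that uses a private, mutable Map (like the implementations you may have found), for example:

import scala.collection.mutable

def memoize[A, B](f: A => B): A => B = {
  // private cache, reachable only through the returned closure
  val cache = mutable.Map.empty[A, B]
  // on a miss, compute f(a), store it and return it; on a hit, return the stored value
  (a: A) => cache.getOrElseUpdate(a, f(a))
}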

Note that the cache is not public. So, for a pure function f and without looking at memory consumption, timings, reflection or other evil stuff, you won't be able to tell from the outside whether f was called twice or g cached the result of f.
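A small sketch of that point: the cache stays invisible only as long as f itself is pure. If f has an observable effect (here a hypothetical call counter, `calls`, inside a hypothetical function `countingSquare`), you can tell that the second call was served from the cache, which is exactly the internal state change discussed above.

```scala
import scala.collection.mutable

// Same cache-based memoize as above.
def memoize[A, B](f: A => B): A => B = {
  val cache = mutable.Map.empty[A, B]
  (a: A) => cache.getOrElseUpdate(a, f(a))
}

// Hypothetical impure function: it counts how often it is invoked.
var calls = 0
val countingSquare: Int => Int = x => { calls += 1; x * x }

val g = memoize(countingSquare)
val r1 = g(3) // cache miss: countingSquare runs, calls becomes 1
val r2 = g(3) // cache hit: countingSquare is not run, calls stays 1
```

With a pure f (no counter), r1 and r2 would be indistinguishable from two direct calls.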

In this sense, side effects are only things like printing output or writing to public variables, files, etc.

Thus, this implementation is considered pure (at least in Scala).
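In the same spirit as the lazy val and LazyList examples above, laziness itself can act as the cache when the domain is the non-negative integers. This is the classic lazy-table trick, a different technique from the Map-based cache: no explicit mutable collection appears, although the runtime still updates each LazyList cell internally, just like a lazy val. `memoizeNat` is a hypothetical name, the sketch only handles inputs >= 0, and looking up index n may evaluate f for smaller indices along the way.

```scala
// Hypothetical helper: memoize a function over the non-negative Ints by laziness.
// Each cell of `table` is computed on first access and then kept by the LazyList.
def memoizeNat[B](f: Int => B): Int => B = {
  val table: LazyList[B] = LazyList.from(0).map(f)
  (n: Int) => {
    require(n >= 0, s"memoizeNat only supports non-negative inputs, got $n")
    table(n)
  }
}

// Hypothetical counting function to observe how often f actually runs.
var evaluations = 0
val slowDouble: Int => Int = n => { evaluations += 1; n * 2 }

val h = memoizeNat(slowDouble)
val first = h(5)
val afterFirst = evaluations // f has run at least once by now
val second = h(5)            // answered from the already-forced cell
```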

Avoiding mutable collections

If you really want to avoid var and mutable collections, you need to change the signature of your memoize method. This is because if g cannot change its internal state, it cannot memoize anything new after it has been created.

An (inefficient but simple) example would be:

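def memoizeOneValue[A, B](f: A => B)(a: A): (B, A => B) = {
  val b = f(a)                               // computed once, here
  val g = (v: A) => if (v == a) b else f(v)  // returns b for `a`, delegates to f otherwise
  (b, g)
}

val (b1, g) = memoizeOneValue(f)(a1)  // g remembers f(a1)
val (b2, h) = memoizeOneValue(g)(a2)  // h remembers g(a2) and falls back to g
// ...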

Only the result of f(a1) is cached in g. You can then chain these calls, getting a new function each time.
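A runnable sketch of that chaining, with concrete values filled in. The counting function `square` and its counter `calls` are hypothetical, added only to make visible which lookups still reach the original function.

```scala
// memoizeOneValue as in the answer: caches f(a) inside a new function.
def memoizeOneValue[A, B](f: A => B)(a: A): (B, A => B) = {
  val b = f(a)
  val g = (v: A) => if (v == a) b else f(v)
  (b, g)
}

// Hypothetical counting function.
var calls = 0
val square: Int => Int = x => { calls += 1; x * x }

val (b1, g) = memoizeOneValue(square)(2) // square(2) runs: calls == 1
val (b2, h) = memoizeOneValue(g)(3)      // g misses on 3, square(3) runs: calls == 2
val c1 = h(2) // h delegates to g, which answers from its closure
val c2 = h(3) // h answers from its own closure
```

Every new key adds one more link to this chain, and a lookup may walk through all of them before reaching square, which is where the O(n) cost comes from.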

If you are interested in a faster version of this, see @esse's answer, which does the same thing more efficiently: it uses an immutable map, so a lookup takes O(log(n)) instead of O(n) for the linked list of functions above.

CodePudding user response:

Let's try (note: I have changed the return type of memoize so that it also carries the cached data):

import scala.language.existentials

type M[A, B] = A => T forSome { type T <: (B, A => T) }

def memoize[A, B](f: A => B): M[A, B] = {
  import scala.collection.immutable
  
  def withCache(cache: immutable.Map[A, B]): M[A, B] = a => cache.get(a) match {
    case Some(b) => (b, withCache(cache))
    case None    =>
      val b = f(a)
      (b, withCache(cache + (a -> b)))
  }
  withCache(immutable.Map.empty)
}

def f(i: Int): Int = { println(s"Invoke f($i)"); i }

val (i0, m0) = memoize(f)(1)    // f is invoked only here, on the first call
val (i1, m1) = m0(1)            // cache hit: f is not invoked
val (i2, m2) = m1(1)            // cache hit again
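The existential type is needed only because the returned function type has to mention itself. Note that `forSome` existential types exist only in Scala 2; Scala 3 dropped them. A possibly simpler sketch, assuming you are free to change the return type as this answer already does, is a recursive case class that carries its own immutable cache. `Memo` and the counting function `inc` are hypothetical names, not part of the answer above.

```scala
// Hypothetical recursive wrapper: each call returns the result together with
// a Memo whose immutable cache includes that result.
final case class Memo[A, B](f: A => B, cache: Map[A, B] = Map.empty[A, B]) {
  def apply(a: A): (B, Memo[A, B]) = cache.get(a) match {
    case Some(b) => (b, this)               // hit: the same Memo can be reused
    case None =>
      val b = f(a)                          // miss: compute once
      (b, copy(cache = cache + (a -> b)))   // and return a Memo that remembers it
  }
}

// Hypothetical counting function to show which calls reach f.
var calls = 0
val inc: Int => Int = i => { calls += 1; i + 1 }

val m0 = Memo(inc)
val (r0, m1) = m0(1) // miss: inc(1) runs
val (r1, m2) = m1(1) // hit: served from m1.cache
val (r2, m3) = m2(2) // miss: inc(2) runs
```

As with the existential version, the caller has to thread the returned Memo forward; calling the old one again simply recomputes.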