Implementing a recursive, stack-safe function using the State monad

Time: 11-07

I want to implement a function representing a while loop using the State monad from cats.
I did it like this:

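import cats.data.State

def whileLoopState[S](cond: S => Boolean)(block: S => S): State[S, Unit] = State { state =>
  if (cond(state)) {
    val nextState = block(state)
    // .value forces the rest of the loop right here, so every iteration adds stack frames
    whileLoopState(cond)(block).run(nextState).value
  } else {
    (state, ())
  }
}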

The problem with this implementation is that it is not stack safe: the recursive call is not in tail position (it is forced with .value inside the function passed to State), so the following results in a StackOverflowError:

whileLoopState[Int](s => s > 0) { s =>
  println(s)
  s - 1
}.run(10000).value

Cats provides a tailRecM method on every Monad instance, which makes it possible to write stack-safe monadic recursion:

import cats.Monad
import cats.data.State

type WhileLoopState[A] = State[Unit, A]

def whileLoopStateTailRec[S](cond: S => Boolean)(block: S => S)(initialState: S): WhileLoopState[S] = Monad[WhileLoopState]
  .tailRecM(initialState) { newState =>
    State { _ =>
      if (cond(newState)) {
        val nextState = block(newState)
        ((), Left(nextState)) // Left: keep looping with nextState
      } else {
        ((), Right(newState)) // Right: stop, returning the final value
      }
    }
  }

Now this works:

whileLoopStateTailRec[Int](s => s > 0) { s =>
  println(s)
  s - 1
}(10000).run(()).value

but the implementation of whileLoopStateTailRec seems too convoluted for such a simple case, which makes me suspect that I'm not doing things correctly.

Is there a way to simplify it?
Is it possible to use State[A, Unit] instead of State[Unit, A] so that the state is kept in the proper slot?
Is it possible to make a recursive function built on the State monad stack safe without using tailRecM?

Answer:

You can either simply take advantage of the fact that flatMap on State is stack safe (State is built on Eval, which trampolines flatMap), like this:

import cats.data.State
import cats.syntax.all._ // provides >>

def whileLoopState[S](cond: S => Boolean)(block: S => S): State[S, Unit] =
  State.get[S].flatMap { s =>
    if (cond(s)) State.set(block(s)) >> whileLoopState(cond)(block)
    else State.pure(())
  }
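A self-contained sketch of this flatMap-based loop, assuming cats 2.x on Scala 2.13 (the object name FlatMapLoop and the iteration count are illustrative, not from the answer). State[S, A] is an alias for StateT[Eval, S, A], and Eval evaluates chains of flatMap with a trampoline instead of nested JVM calls; in cats 2.x the >> operator also takes its right-hand side by name, so the recursive call is only built once the previous step has run.

```scala
import cats.data.State
import cats.syntax.all._

object FlatMapLoop {
  // Each iteration reads the state, then either writes the updated state
  // and recurses, or stops with ().
  def whileLoopState[S](cond: S => Boolean)(block: S => S): State[S, Unit] =
    State.get[S].flatMap { s =>
      if (cond(s)) State.set[S](block(s)) >> whileLoopState(cond)(block)
      else State.pure[S, Unit](())
    }

  def main(args: Array[String]): Unit = {
    // 100000 iterations, far deeper than the 10000 that overflowed the naive version
    val finalState = whileLoopState[Int](_ > 0)(_ - 1).runS(100000).value
    println(finalState) // prints 0
  }
}
```

runS returns only the final state; run would return the (state, ()) pair instead.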

Or, even better, reuse the existing whileM_ syntax:

import cats.data.State
import cats.syntax.all._ // provides whileM_

def whileLoopState[S](cond: S => Boolean)(block: S => S): State[S, Unit] =
  State.modify(block).whileM_(State.inspect(cond))
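whileM_ comes from the Monad type class (reached here through cats.syntax.all); in cats 2.x it is implemented on top of tailRecM, which is why no hand-written recursion is needed. A small sketch under the same assumptions (cats 2.x, Scala 2.13), using a hypothetical sumTo helper whose state is a (counter, accumulator) pair:

```scala
import cats.data.State
import cats.syntax.all._

object WhileMLoop {
  // The one-liner from the answer: check cond on the current state, run block while it holds.
  def whileLoopState[S](cond: S => Boolean)(block: S => S): State[S, Unit] =
    State.modify(block).whileM_(State.inspect(cond))

  // Hypothetical example: add the counter to the accumulator and
  // decrement the counter until it reaches 0.
  def sumTo(n: Long): Long =
    whileLoopState[(Long, Long)] { case (i, _) => i > 0 } { case (i, acc) => (i - 1, acc + i) }
      .runS((n, 0L))
      .value
      ._2
}
```

WhileMLoop.sumTo(10) returns 55, and because the loop runs through tailRecM, a large n such as 100000 does not overflow the stack either.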

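On keeping the state in the proper slot: tailRecM can also drive a State[S, Unit] directly, by looping over Unit and letting State carry S, the mirror image of the State[Unit, A] version in the question. A sketch under the same assumptions; the local alias St and the object name TailRecMLoop are hypothetical:

```scala
import cats.Monad
import cats.data.State

object TailRecMLoop {
  def whileLoopState[S](cond: S => Boolean)(block: S => S): State[S, Unit] = {
    // Local alias so a Monad instance can be summoned for State with S fixed
    type St[A] = State[S, A]
    // The tailRecM seed is (): the loop carries no value of its own,
    // the evolving S lives in the State
    Monad[St].tailRecM(()) { _ =>
      State[S, Either[Unit, Unit]] { s =>
        if (cond(s)) (block(s), Left(())) // update the state, keep looping
        else (s, Right(()))               // leave the state as is, stop
      }
    }
  }
}
```

This is roughly what whileM_ does internally, so the one-liner is usually the simpler choice.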
