I want to implement a function representing a while loop using the State monad from cats. I did it like this:
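import cats.data.State

// Not stack safe: .run(nextState).value forces the rest of the loop
// eagerly, so every iteration adds more JVM stack frames.
def whileLoopState[S](cond: S => Boolean)(block: S => S): State[S, Unit] = State { state =>
  if (cond(state)) {
    val nextState = block(state)
    whileLoopState(cond)(block).run(nextState).value
  } else {
    (state, ())
  }
}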
The problem with this implementation is that it's not stack safe: the recursive call is not in tail position (each iteration forces the rest of the loop with .value), so the following throws a StackOverflowError:
whileLoopState[Int](s => s > 0) { s =>
  println(s)
  s - 1
}.run(10000).value
Cats provides a tailRecM method on every instance of the Monad type class, which makes monadic recursive functions stack safe:
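import cats.Monad
import cats.data.State

// The loop variable S is threaded through the value slot; the state slot is Unit.
type WhileLoopState[A] = State[Unit, A]

def whileLoopStateTailRec[S](cond: S => Boolean)(block: S => S)(initialState: S): WhileLoopState[S] =
  Monad[WhileLoopState].tailRecM(initialState) { newState =>
    State { _ =>
      if (cond(newState)) {
        val nextState = block(newState)
        ((), Left(nextState)) // Left: run another iteration
      } else {
        ((), Right(newState)) // Right: stop with the final value
      }
    }
  }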
Now this works:
whileLoopStateTailRec[Int](s => s > 0) { s =>
  println(s)
  s - 1
}(10000).run(()).value
but the implementation of whileLoopStateTailRec seems too convoluted for such a simple case, which makes me suspect that I'm not doing things correctly.
Is there a way to simplify it?
Is it possible to use State[A, Unit] instead of State[Unit, A], so that the state is kept in the proper slot?
Is it possible to make a recursive function using the State monad stack safe without using tailRecM?
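To make the Left/Right convention in the tailRecM code above concrete: tailRecM(a)(f) keeps applying f, where a Left(next) result means "iterate again with next" and a Right(result) ends the loop with result. A minimal sketch using Option instead of State, assuming cats 2.x on the classpath; the object and countTo names are hypothetical:

```scala
import cats.Monad
import cats.implicits._

object TailRecMContract {
  // Counts from 0 up to `limit` without growing the stack:
  // each step returns Left(next) to continue or Right(result) to stop.
  def countTo(limit: Int): Option[Int] =
    Monad[Option].tailRecM[Int, Int](0) { i =>
      if (i < limit) Some(Left(i + 1)) else Some(Right(i))
    }

  def main(args: Array[String]): Unit =
    println(countTo(100000)) // Some(100000)
}
```

The same shape applies to State: the step function returns a State whose value is an Either, and tailRecM drives the iteration.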
Answer:
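You can either take advantage of the fact that flatMap on State is stack safe (State[S, A] is an alias for StateT[Eval, S, A], and Eval evaluates flatMap chains without growing the stack), like this: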
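import cats.data.State
import cats.syntax.all._ // provides >>

def whileLoopState[S](cond: S => Boolean)(block: S => S): State[S, Unit] =
  State.get[S].flatMap { s =>
    // The recursive call is only built inside the flatMap continuation,
    // so it is evaluated lazily by Eval instead of on the JVM stack.
    if (cond(s)) State.set(block(s)) >> whileLoopState(cond)(block)
    else State.pure[S, Unit](())
  }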
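A quick way to check the stack-safety claim is to run the flatMap-based version for far more iterations than the question's 10000. A minimal sketch, assuming cats 2.x on the classpath; the wrapper object name is hypothetical:

```scala
import cats.data.State
import cats.syntax.all._

object WhileLoopFlatMapDemo {
  // Same definition as in the answer, wrapped in an object so it compiles on its own.
  def whileLoopState[S](cond: S => Boolean)(block: S => S): State[S, Unit] =
    State.get[S].flatMap { s =>
      if (cond(s)) State.set(block(s)) >> whileLoopState(cond)(block)
      else State.pure[S, Unit](())
    }

  def main(args: Array[String]): Unit = {
    // run returns Eval[(S, Unit)]; .value evaluates it in constant stack space.
    val (finalState, _) = whileLoopState[Int](_ > 0)(_ - 1).run(100000).value
    println(finalState) // 0
  }
}
```

If you only need the final state, runS(initial).value returns it directly instead of the (state, ()) pair.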
Or, even better, reuse the existing whileM_ combinator from the Cats Monad syntax:
import cats.data.State
import cats.syntax.all._ // provides whileM_

def whileLoopState[S](cond: S => Boolean)(block: S => S): State[S, Unit] =
  State.modify(block).whileM_(State.inspect(cond))
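For comparison with the question's tailRecM attempt, the same loop can still be written with an explicit tailRecM while keeping S in State's state slot; as far as I know, this is essentially how Cats builds whileM_ itself. A sketch assuming cats 2.x; whileLoopTailRecM and the object name are hypothetical:

```scala
import cats.Monad
import cats.data.State
import cats.syntax.all._

object WhileLoopTailRecMDemo {
  // Explicit tailRecM loop: the loop variable lives in the state slot,
  // and tailRecM only carries Unit between iterations.
  def whileLoopTailRecM[S](cond: S => Boolean)(block: S => S): State[S, Unit] = {
    type StateS[A] = State[S, A] // fixes the state type so Monad can be summoned
    Monad[StateS].tailRecM[Unit, Unit](()) { _ =>
      State.inspect[S, Boolean](cond).flatMap { keepGoing =>
        // Left(()) asks for another iteration, Right(()) ends the loop.
        if (keepGoing) State.modify[S](block).map[Either[Unit, Unit]](_ => Left(()))
        else State.pure[S, Either[Unit, Unit]](Right(()))
      }
    }
  }

  // The whileM_ version from the answer, for side-by-side comparison.
  def whileLoopState[S](cond: S => Boolean)(block: S => S): State[S, Unit] =
    State.modify(block).whileM_(State.inspect(cond))

  def main(args: Array[String]): Unit = {
    println(whileLoopTailRecM[Int](_ > 0)(_ - 1).runS(100000).value) // 0
    println(whileLoopState[Int](_ > 0)(_ - 1).runS(100000).value) // 0
  }
}
```

Both versions check the condition before running the body, so a condition that is false initially leaves the state unchanged.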
You can see the code running here.