Refactor impure recursion with state monad?

Time:12-31

I've been dissecting this one-liner solution for Advent of Code (AoC) 2022 day 14 and came across an elegant impure recursive solution:

def s(x,y):
    if y>h: return True
    if (x,y) in m:
        return False
    return next((r for d in (0,-1,1) if (r:=s(x+d,y+1))), None) or m.add((x,y))

full solution on godbolt

One way you could make this pure is by explicitly passing and returning the set m from the function s (i.e. s :: int -> int -> set -> (bool, set)).
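A minimal sketch of that explicit-threading version, written in Haskell rather than Python. The name sPure, the Pos alias, and the tiny three-cell floor in main are made up for illustration; the global h becomes the first parameter.

```haskell
import Data.Set (Set)
import qualified Data.Set as Set

type Pos = (Int, Int)

-- Hypothetical pure variant of s: the occupied set m is an explicit
-- argument and is returned alongside the Boolean result.
sPure :: Int -> Int -> Int -> Set Pos -> (Bool, Set Pos)
sPure h x y m
  | y > h = (True, m)                    -- fell past the lowest rock
  | (x, y) `Set.member` m = (False, m)   -- cell already occupied
  | otherwise = go [0, -1, 1] m
  where
    -- Try down, down-left, down-right, threading the set through each call.
    go [] m' = (False, Set.insert (x, y) m')   -- all blocked: settle here
    go (d : ds) m' =
      case sPure h (x + d) (y + 1) m' of
        (True, m'')  -> (True, m'')
        (False, m'') -> go ds m''

main :: IO ()
main = do
  let rocks = Set.fromList [(x, 2) | x <- [-1 .. 1]]  -- made-up floor at y = 2
      (fell, m) = sPure 2 0 0 rocks
  print (fell, Set.size m - Set.size rocks)  -- prints (True,1)
```

Every recursive call has to receive the latest set and hand back the updated one, which is exactly the bookkeeping a state monad hides.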

However, I've also read about how the reader/writer/state monads save you from having to pass the extra parameter and handle the tuple result, and am interested in porting this recursion to Haskell.

I found a Haskell solution on Reddit that looks like it may do the same recursion (as well as two more that don't).

fill :: (MArray a Bool (ST s), Ix i, Num i, Show i) => a (i, i) Bool -> i -> ST s (Int, Int)
fill blocks maxY = do
    counterAtMaxY <- newSTRef Nothing
    counter <- newSTRef 0
    let fill' (x, y) = readArray blocks (x, y) >>= flip bool (pure ()) do
            when (y == maxY) $ readSTRef counterAtMaxY >>= maybe
                (readSTRef counter >>= writeSTRef counterAtMaxY . Just) (const $ pure ())
            when (y <= maxY) $ fill' (x, y + 1) >> fill' (x - 1, y + 1) >> fill' (x + 1, y + 1)
            writeArray blocks (x, y) True >> modifySTRef' counter (+ 1)
    fill' (500, 0)
    counterAtMaxY <- readSTRef counterAtMaxY
    counter <- readSTRef counter
    pure (fromMaybe counter counterAtMaxY, counter)

full solution on godbolt

Could someone confirm that this is indeed a port of the Python solution? If so, could you walk me through how the recursion happens?

I'm still not Haskell-literate. I can kind of make out that the line fill' (500, 0) in the do block means m >>= \_ -> fill' (500, 0), which I read as discarding the current state and creating a new monad independently (something gets preserved, but I'm confused about what). I also don't understand monad transformers at all.

The Haskell solution does part 2 of the question simultaneously, so maybe someone can factor that out so there's no confusion between the cartesian coordinates and the pair of ints containing the solution.

CodePudding user response:

(partial answer only: below I only clarify the Haskell code, but I did not compare it against the Python code or the task)

I kind of dislike the Haskell code. I believe it aims at point-free style too much. I mean, flip bool to avoid a variable? maybe to choose between two complex branches? Nah, I'd use a plain if/case.

Still, it's working in the ST s monad, so it reads like imperative code. The state is never discarded, only modified, as in imperative languages. Read a >>= b roughly as result = a() ; b(result), except that the code stubbornly avoids introducing the variable result.
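To make that reading concrete, here is a small sketch, unrelated to the puzzle, of the same three ST steps written once with explicit >>= and >> and once in do-notation. The names bumpBind and bumpDo are made up for illustration.

```haskell
import Control.Monad.ST (runST)
import Data.STRef (modifySTRef', newSTRef, readSTRef)

-- Written with explicit (>>=) and (>>), in the style of the original fill.
bumpBind :: Int -> Int
bumpBind n = runST (newSTRef n >>= \ref -> modifySTRef' ref (+ 1) >> readSTRef ref)

-- The same steps in do-notation, naming the intermediate result.
bumpDo :: Int -> Int
bumpDo n = runST $ do
  ref <- newSTRef n        -- result = a()
  modifySTRef' ref (+ 1)   -- a step whose result is ignored, i.e. (>>)
  readSTRef ref            -- the last action's result is the block's result

main :: IO ()
main = print (bumpBind 41, bumpDo 41)  -- prints (42,42)
```

The read after >> sees the incremented value: sequencing with >> ignores the previous action's result, not the changes it made to the reference.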

Here is the code, rewritten to (IMHO) improve readability.

fill :: (MArray a Bool (ST s), Ix i, Num i, Show i) 
     => a (i, i) Bool -> i -> ST s (Int, Int)
fill blocks maxY = do
    counterAtMaxY <- newSTRef Nothing
    counter <- newSTRef 0
    let fill' (x, y) = do
           -- test the boolean flag at (x,y)
           b <- readArray blocks (x, y)
           -- if false, we did not visit (x,y) before
           when (not b) $ do
              when (y == maxY) $ do
                 -- if counterAtMaxY is Nothing,
                 -- replace it with counter
                 mayc <- readSTRef counterAtMaxY
                 case mayc of
                    Nothing -> do
                       c <- readSTRef counter
                       writeSTRef counterAtMaxY (Just c)
                    Just _ -> pure ()
              when (y <= maxY) $ do
                 -- recurse thrice
                 fill' (x, y + 1)
                 fill' (x - 1, y + 1)
                 fill' (x + 1, y + 1)
              -- mark (x,y) as visited
              writeArray blocks (x, y) True
              -- increment counter
              modifySTRef' counter (+ 1)
    fill' (500, 0)
    counterAtMaxY <- readSTRef counterAtMaxY
    counter <- readSTRef counter
    pure (fromMaybe counter counterAtMaxY, counter)

I'd also rewrite the y == maxY case as follows:

when (y == maxY) $ do
   -- if counterAtMaxY is Nothing,
   -- replace it with counter
   c <- readSTRef counter       
   modifySTRef counterAtMaxY $ \mayc ->
      case mayc of
         Nothing -> Just c
         Just _ -> mayc

or even

when (y == maxY) $ do
   -- if counterAtMaxY is Nothing,
   -- replace it with counter
   c <- readSTRef counter       
   modifySTRef counterAtMaxY (<|> Just c)

where (<|> Just c), with <|> imported from Control.Applicative, is a function acting as the identity on Just _ values but mapping Nothing to Just c, which is similar to the Python ... or c.
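A quick check of that behaviour on Maybe values; keepFirst is just a made-up name for the section (<|> Just c).

```haskell
import Control.Applicative ((<|>))

-- keepFirst c is the section (<|> Just c), given a hypothetical name.
keepFirst :: Int -> Maybe Int -> Maybe Int
keepFirst c = (<|> Just c)

main :: IO ()
main = do
  print (keepFirst 7 Nothing)   -- prints Just 7
  print (keepFirst 7 (Just 3))  -- prints Just 3
```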

CodePudding user response:

Below is a fairly close translation of your Python code to Haskell. Some remarks on the differences:

  • The global h becomes a local parameter, and m :: Set (Int, Int) gets passed implicitly in the State monad, accessed using get and modify.
  • There is no early return in Haskell (calling return/pure doesn't abort the rest of the function, you have to put it at the end of the block). On the other hand, if expressions must have an else clause, so that forces you to do the right thing anyway.
  • The generator expression can be written as a higher-order function which tries each action in a list, stopping as soon as one returns True.
  • The add method of a Python set returns None, which gets interpreted as False in conditionals. In Haskell we don't like this kind of overloading; instead, we explicitly attach the False value to the value-less action add (x, y) by writing add (x, y) *> pure False.
  • Use execState to "run the monadic program" s h0 500 0 with an initial state m0, obtaining its final state. That "program" s h0 500 0 :: M Bool is actually a pure function Set (Int, Int) -> (Bool, Set (Int, Int)), and all execState does is to apply that to the initial state and project out the second component of the output pair. The point of the "state monad" is that such a function can be defined with the syntax of an imperative language ("do-notation").
module Main where

import Control.Monad.State
import Data.Set (Set)
import qualified Data.Set as Set

type M = State (Set (Int, Int))

s :: Int -> Int -> Int -> M Bool
s h x y =
  if y > h then pure True
  else do
    m <- get
    if Set.member (x, y) m then
      pure False
    else
      orM ([s h (x + d) (y + 1) | d <- [0, -1, 1]] ++ [add (x, y) *> pure False])

orM :: Monad m => [m Bool] -> m Bool
orM [] = pure False
orM (x : xs) = do
  b <- x
  if b then pure True
  else orM xs

add :: (Int, Int) -> M ()
add (x, y) = modify (Set.insert (x, y))

-- Example from https://adventofcode.com/2022/day/14

m0 :: Set (Int, Int)
m0 = vline 498 4 6 <> hline 498 496 6 <> hline 503 502 4 <> vline 502 4 9 <> hline 494 502 9

vline, hline :: Int -> Int -> Int -> Set (Int, Int)
vline x y1 y2 | y1 > y2 = vline x y2 y1
vline x y1 y2 = Set.fromList [(x, y) | y <- [y1 .. y2]]

hline x1 x2 y | x1 > x2 = hline x2 x1 y
hline x1 x2 y = Set.fromList [(x, y) | x <- [x1 .. x2]]

h0 :: Int
h0 = 9

main :: IO ()
main =
  print (Set.size (execState (s h0 500 0) m0) - Set.size m0)
  -- Output: 24
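
To illustrate the remark that a State action is really a pure function from a state to a (result, state) pair, here is a small sketch on a plain Int counter. The names tick and tickFn are made up; like the code above, it assumes Control.Monad.State from the mtl package.

```haskell
import Control.Monad.State (State, execState, get, modify, runState, state)

-- A tiny stateful action in do-notation: return the old counter, then increment it.
tick :: State Int Int
tick = do
  n <- get
  modify (+ 1)
  pure n

-- The same action built directly from a function state -> (result, new state).
tickFn :: State Int Int
tickFn = state (\n -> (n, n + 1))

main :: IO ()
main = do
  print (runState (tick >> tick) 10)      -- prints (11,12)
  print (runState (tickFn >> tickFn) 10)  -- prints (11,12)
  print (execState (tick >> tick) 10)     -- prints 12
```

runState applies the underlying function to the initial state and returns the whole pair, while execState keeps only the final state, which is all main above needs for counting the grains.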