I was solving a recursive problem in Haskell. I can get the solution, but I would like to cache the outputs of subproblems, since the problem has the overlapping-subproblems property.
The question is: given a grid of dimension n*m and an integer k, how many ways are there to reach cell (n, m) from (1, 1) with no more than k changes of direction?
Here is the code without memoization:
paths :: Int -> Int -> Int -> Int -> Int -> Int -> Integer
paths i j n m k dir
    | i > n || j > m || k < 0 = 0
    | i == n && j == m = 1
    | dir == 0 = paths (i+1) j n m k 1 + paths i (j+1) n m k 2     -- at the starting cell (1,1)
    | dir == 1 = paths (i+1) j n m k 1 + paths i (j+1) n m (k-1) 2 -- down was the direction taken to reach here
    | dir == 2 = paths (i+1) j n m (k-1) 1 + paths i (j+1) n m k 2 -- right was the direction taken to reach here
    | otherwise = -1
Here the dependent variables are i, j, k, and dir. In languages like C or Java a 4-D DP array could be used (dp[n][m][k][3]), but in Haskell I can't find a way to implement that.
CodePudding user response:
As I mentioned in a comment, "tying the knot" is a well-known technique for getting the GHC runtime to memoize results for you, if you know ahead of time all the values you will ever need to look up. The idea is to turn your recursive function into a self-referential data structure, and then simply look up the value you actually care about. I chose to use Array for this, but a Map would work as well. In either case, you must use a lazy / non-strict array/map, because we will be inserting values into it that we aren't ready to compute until the whole array is filled.
import Data.Array (array, bounds, inRange, (!))

paths :: Int -> Int -> Int -> Integer
paths m n k = go (1, 1, k, 0)
  where go (i, j, k, dir)
          | i == m && j == n = 1
          | dir == 1 = get (i+1, j, k, 1) + get (i, j+1, k-1, 2) -- down was the direction taken to reach here
          | dir == 2 = get (i+1, j, k-1, 1) + get (i, j+1, k, 2) -- right was the direction taken to reach here
          | otherwise = get (i+1, j, k, 1) + get (i, j+1, k, 2)  -- at the starting cell (1,1)
        a = array ((1, 1, 0, 1), (m, n, k, 2))
              [(c, go c) | c <- (,,,) <$> [1..m] <*> [1..n] <*> [0..k] <*> [1..2]]
        get x | inRange (bounds a) x = a ! x
              | otherwise = 0
I simplified your API a bit:
- The m and n parameters don't change with each iteration, so they shouldn't be part of the recursive call.
- The client shouldn't have to tell you what i, j, and dir start as, so they've been removed from the function signature and implicitly start at 1, 1, and 0 respectively.
- I also swapped the order of m and n, because it's just weird to take an n parameter first. This caused me quite a bit of headache, because I didn't notice for a while that I also needed to change the base case!
Then, as I said earlier, the idea is to fill up the array with all the recursive calls we'll need to make: that's the array call. Notice that the cells in the array are initialized with a call to go, which (except for the base case!) involves calling get, which involves looking up an element in the array. In this way, a is self-referential or recursive. But we don't have to decide what order to look things up in, or what order to insert them in: we're sufficiently lazy that GHC evaluates the array elements as needed.
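To see the self-referential array in isolation, here is a minimal one-dimensional sketch of the same idea. The name fibMemo is made up for illustration and is not part of the answer's code; it only assumes Data.Array from the array package that ships with GHC.

```haskell
import Data.Array (Array, listArray, (!))

-- Hypothetical illustration: Fibonacci numbers memoized by tying the knot.
fibMemo :: Int -> Integer
fibMemo n
  | n < 0 = error "fibMemo: negative index"
  | otherwise = table ! n
  where
    -- Each cell is a lazy thunk; evaluating it may force other cells
    -- of the very same array, and every cell is computed at most once.
    table :: Array Int Integer
    table = listArray (0, n) (map go [0 .. n])

    go :: Int -> Integer
    go 0 = 0
    go 1 = 1
    go i = table ! (i - 1) + table ! (i - 2)
```

Nothing here says in which order the cells are filled; asking for table ! n forces exactly the cells it depends on, which is the same mechanism the four-dimensional array relies on.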
I've also been a bit cheeky by only making space in the array for dir=1 and dir=2, not dir=0. I get away with this because dir=0 only happens on the first call, and I can call go directly for that case, bypassing the bounds-checking in get. This trick does mean you'll get a runtime error if you pass an m or n less than 1, or a k less than zero. You could add a guard for that to paths itself, if you need to handle that case.
And of course, it does indeed work:
> paths 3 3 2
4
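If invalid inputs need handling, one possible shape for the guard mentioned above is a thin wrapper. This is only a sketch: safePaths is a hypothetical name, and returning 0 for a degenerate grid or a negative k is an assumption that mirrors the question's original guard (i > n || j > m || k < 0 = 0).

```haskell
import Data.Array (array, bounds, inRange, (!))

-- Hypothetical wrapper: pin the result to 0 for inputs that describe
-- no valid grid or a negative turn budget (an assumed convention).
safePaths :: Int -> Int -> Int -> Integer
safePaths m n k
  | m < 1 || n < 1 || k < 0 = 0
  | otherwise = paths m n k

-- The array-based function from the answer, repeated so the sketch compiles.
paths :: Int -> Int -> Int -> Integer
paths m n k0 = go (1, 1, k0, 0)
  where
    go :: (Int, Int, Int, Int) -> Integer
    go (i, j, k, dir)
      | i == m && j == n = 1
      | dir == 1 = get (i + 1, j, k, 1) + get (i, j + 1, k - 1, 2)
      | dir == 2 = get (i + 1, j, k - 1, 1) + get (i, j + 1, k, 2)
      | otherwise = get (i + 1, j, k, 1) + get (i, j + 1, k, 2)
    a = array ((1, 1, 0, 1), (m, n, k0, 2))
          [(c, go c) | c <- (,,,) <$> [1 .. m] <*> [1 .. n] <*> [0 .. k0] <*> [1 .. 2]]
    get x
      | inRange (bounds a) x = a ! x
      | otherwise = 0
```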
One other thing you could do would be to use a real data type for your direction, instead of an Int:
import Data.Array (Ix, array, bounds, inRange, (!))
import Prelude hiding (Right)

data Direction = Neutral | Down | Right deriving (Eq, Ord, Ix)

paths :: Int -> Int -> Int -> Integer
paths m n k = go (1, 1, k, Neutral)
  where go (i, j, k, dir)
          | i == m && j == n = 1
          | otherwise = case dir of
            Neutral -> get (i+1, j, k, Down) + get (i, j+1, k, Right)
            Down -> get (i+1, j, k, Down) + get (i, j+1, k-1, Right)
            Right -> get (i+1, j, k-1, Down) + get (i, j+1, k, Right)
        a = array ((1, 1, 0, Down), (m, n, k, Right))
              [(c, go c) | c <- (,,,) <$> [1..m] <*> [1..n] <*> [0..k] <*> [Down, Right]]
        get x | inRange (bounds a) x = a ! x
              | otherwise = 0
(I and J might be better names than Down and Right; I don't know if that's easier or harder to remember.) I think this is probably an improvement, since the types have more meaning now, and you don't have this weird otherwise clause that handles things like dir=7, which ought to be illegal. But it is still a bit wonky because it relies on the ordering of the enum values: it would break if we put Neutral in between Down and Right. (I tried removing the Neutral direction entirely and adding more special-casing for the first step, but this gets ugly in its own way.)
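The Map alternative mentioned at the start happens to sidestep the ordering concern, because a map only stores the keys you insert and has no contiguous index range. A possible sketch, assuming the lazy Data.Map from the containers package; pathsMap, table, and k0 are names chosen here for illustration:

```haskell
import qualified Data.Map as Map
import Prelude hiding (Right)

-- Only Eq and Ord are needed; the constructor order no longer matters.
data Direction = Neutral | Down | Right deriving (Eq, Ord)

pathsMap :: Int -> Int -> Int -> Integer
pathsMap m n k0 = go (1, 1, k0, Neutral)
  where
    go (i, j, k, dir)
      | i == m && j == n = 1
      | otherwise = case dir of
          Neutral -> get (i + 1, j, k, Down) + get (i, j + 1, k, Right)
          Down    -> get (i + 1, j, k, Down) + get (i, j + 1, k - 1, Right)
          Right   -> get (i + 1, j, k - 1, Down) + get (i, j + 1, k, Right)

    -- Data.Map is lazy in its values, so each entry is a thunk that
    -- may look up other entries of the same map when it is forced.
    table = Map.fromList
      [ (c, go c)
      | c <- (,,,) <$> [1 .. m] <*> [1 .. n] <*> [0 .. k0] <*> [Down, Right]
      ]

    -- Keys outside the grid, or with a negative budget, were never inserted.
    get x = Map.findWithDefault 0 x table
```

Lookups cost a logarithmic factor compared with the array, so this trades some speed for not depending on Ix bounds.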
CodePudding user response:
In Haskell these kinds of things aren't the most trivial ones, indeed. You would really like to have some in-place mutation going on to save memory and time, so I don't see any better way than reaching for the frightening ST monad.
This could be done over various data structures: arrays, vectors, repa tensors. I chose HashTable from hashtables because it is the simplest to use and is performant enough to make sense in my example.
First of all, introduction:
{-# LANGUAGE Rank2Types #-}
module Solution where
import Control.Monad.ST
import Control.Monad
import Data.HashTable.ST.Basic as HT
Rank2Types are useful when dealing with ST, because of the phantom types. I picked the Basic variant of the hashtable because the authors claim it has the fastest lookups, and we are going to look up a lot.
It is advised to use a type alias for the map, so here we go:
type Mem s = HT.HashTable s (Int, Int, Int, Int) Integer
An ST-free entry point just creates the map and calls our monster:
runpaths :: Int -> Int -> Int -> Int -> Int -> Int -> Integer
runpaths i j n m k dir = runST $ do
    mem <- HT.new
    paths mem i j n m k dir
Here is the memoized computation of paths. We first search for the result in the map; if it is not there, we compute it, save it, and return it:
mempaths mem i j n m k dir = do
    res <- HT.lookup mem (i, j, k, dir)
    case res of
        Just x -> return x
        Nothing -> do
            x <- paths mem i j n m k dir
            HT.insert mem (i, j, k, dir) x
            return x
And here goes the brain of the algorithm. It is just a monadic action that uses calls with memoization in place of plain recursion:
paths mem i j n m k dir
    | i > n || j > m || k < 0 = return 0
    | i == n && j == m = return 1
    | dir == 0 = do
        x1 <- mempaths mem (i+1) j n m k 1
        x2 <- mempaths mem i (j+1) n m k 2 -- at the starting cell (1,1)
        return $ x1 + x2
    | dir == 1 = do
        x1 <- mempaths mem (i+1) j n m k 1
        x2 <- mempaths mem i (j+1) n m (k-1) 2 -- down was the direction taken to reach here
        return $ x1 + x2
    | dir == 2 = do
        x1 <- mempaths mem (i+1) j n m (k-1) 1
        x2 <- mempaths mem i (j+1) n m k 2 -- right was the direction taken to reach here
        return $ x1 + x2
    | otherwise = return (-1)
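For trying the lookup-then-insert pattern without installing hashtables, here is a rough stand-in that keeps the cache as a Data.Map inside an STRef. This swaps the hash table for a different structure, so it is not the code above; the names runPathsRef, memPathsRef, and pathsRef are made up for this sketch, and it always starts at (1, 1) with dir = 0 rather than taking those as arguments.

```haskell
import Control.Monad.ST (ST, runST)
import Data.STRef (STRef, newSTRef, readSTRef, modifySTRef')
import qualified Data.Map.Strict as Map

type Key = (Int, Int, Int, Int)
type Mem s = STRef s (Map.Map Key Integer)

-- Pure entry point: create the cache, run the memoized search from (1,1).
runPathsRef :: Int -> Int -> Int -> Integer
runPathsRef n m k = runST $ do
  mem <- newSTRef Map.empty
  pathsRef mem n m 1 1 k 0

-- Look the state up in the cache; compute and store it on a miss.
memPathsRef :: Mem s -> Int -> Int -> Int -> Int -> Int -> Int -> ST s Integer
memPathsRef mem n m i j k dir = do
  cache <- readSTRef mem
  case Map.lookup (i, j, k, dir) cache of
    Just x -> return x
    Nothing -> do
      x <- pathsRef mem n m i j k dir
      modifySTRef' mem (Map.insert (i, j, k, dir) x)
      return x

-- dir: 0 = start, 1 = arrived moving down, 2 = arrived moving right.
pathsRef :: Mem s -> Int -> Int -> Int -> Int -> Int -> Int -> ST s Integer
pathsRef mem n m i j k dir
  | i > n || j > m || k < 0 = return 0
  | i == n && j == m = return 1
  | otherwise = do
      let kDown  = if dir == 2 then k - 1 else k  -- turning from right to down
          kRight = if dir == 1 then k - 1 else k  -- turning from down to right
      x1 <- memPathsRef mem n m (i + 1) j kDown 1
      x2 <- memPathsRef mem n m i (j + 1) kRight 2
      return (x1 + x2)
```

The signatures in this sketch are ordinary rank-1 types with a free s, so it compiles without any language extension.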