Haskell - How do I write an idiomatic and efficient loop?

Time:10-24

I'm relatively new to Haskell, so I am working through some old Advent of Code problems to familiarize myself with the language.

However, I got stuck on 2017 day 17, part two. I've tried three solutions to this problem.

(Edit: reduced the code block to a clearer example)

The following solution is something I would expect to work moderately efficiently:

run :: IO ()
run = do
    print "Starting:"
    print (iteration''' 0 1 3 0 50000000)

-- cp: current position, cv: next value to insert, ss: step size,
-- zv: value currently right after 0, count: iterations left
iteration''' :: Int -> Int -> Int -> Int -> Int -> (Int, Int, Int, Int)
iteration''' cp cv ss zv 0 = (cp, cv, ss, zv)
iteration''' cp cv ss zv count = iteration''' ncp ncv ss nzv (count - 1)
    where
        ncp = ((cp + ss) `mod` cv) + 1
        nzv = if ncp == 1 then cv else zv
        ncv = cv + 1

The problem is that all three are horribly inefficient, both in memory and in CPU time. The equivalent C code would be something like the following, which completes very quickly:

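#include <stdio.h>

int main(void) {
    int stepSize = 3;
    int zv = 0;          /* value currently right after 0 */
    int position = 0;
    for (int i = 1; i <= 50000000; i++) {   /* same recurrence as iteration''' */
        position = (position + stepSize) % i + 1;
        if (position == 1) zv = i;
    }
    printf("%d\n", zv);
    return 0;
}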

I assumed iteration''' would compile to something similar, but it eats up gigabytes of memory and runs for a long time.

To summarize my question: what is an idiomatic way to solve this problem efficiently in Haskell? Why does it eat up so much heap space when no objects actually need to be allocated?

I am compiling using ghc (cabal).

Answer:

For completeness, as answered by Daniel Wagner and chi:

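The problem in the stated code was a lack of strictness. The result is returned in a lazy tuple, and the only argument inspected inside the loop is count (by the pattern match against 0), so nothing forces cp, cv, or zv until the very end. Each iteration therefore wraps the previous values in a new unevaluated thunk, and after 50 million iterations that chain of thunks costs gigabytes of heap and a lot of garbage-collection time.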
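As a sketch of the same fix without any language extension, each new state value can be forced with seq before the recursive call. The name iterationSeq is hypothetical; the recurrence is the one from the question. Because forcing an Int to weak head normal form evaluates it completely, no chain of thunks can form.

```haskell
-- Strict variant of the question's loop using seq instead of bang patterns.
-- `a `seq` b` evaluates a to weak head normal form before returning b;
-- for an Int that means the number is fully computed.
iterationSeq :: Int -> Int -> Int -> Int -> Int -> (Int, Int, Int, Int)
iterationSeq cp cv ss zv 0 = (cp, cv, ss, zv)
iterationSeq cp cv ss zv count =
    -- Force the new state before recursing so no thunk chain builds up.
    ncp `seq` ncv `seq` nzv `seq` iterationSeq ncp ncv ss nzv (count - 1)
  where
    ncp = ((cp + ss) `mod` cv) + 1
    nzv = if ncp == 1 then cv else zv
    ncv = cv + 1
```

For example, iterationSeq 0 1 3 0 9 returns (1, 10, 3, 9): after inserting 1 through 9 with step size 3, the value right after 0 is 9.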

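With bang patterns on the arguments, the function runs a lot faster. Bang patterns require the BangPatterns language extension, enabled with the {-# LANGUAGE BangPatterns #-} pragma at the top of the file: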

{-# LANGUAGE BangPatterns #-}

iteration''' :: Int -> Int -> Int -> Int -> Int -> (Int, Int, Int, Int)
iteration''' !cp !cv !ss !zv 0 = (cp, cv, ss, zv)
iteration''' !cp !cv !ss !zv !count = iteration''' ncp ncv ss nzv (count - 1)
    where
        ncp = ((cp + ss) `mod` cv) + 1
        nzv = if ncp == 1 then cv else zv
        ncv = cv + 1
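A common idiom for loops like this is a local worker function (often called go) that takes only the state that changes, with bang patterns on its arguments, and returns only the value that is actually needed. The following is a sketch under that assumption; valueAfterZero is a hypothetical name, and it returns just the value after 0 instead of the whole tuple.

```haskell
{-# LANGUAGE BangPatterns #-}

-- Hypothetical helper: the value right after 0 once the values 1..n have
-- been inserted with step size `step`. The step size is captured from the
-- enclosing scope, so the worker only carries the state that changes.
valueAfterZero :: Int -> Int -> Int
valueAfterZero step n = go 0 1 0
  where
    -- pos: current position, v: next value to insert, zv: value after 0
    go :: Int -> Int -> Int -> Int
    go !pos !v !zv
      | v > n     = zv
      | otherwise =
          let pos' = (pos + step) `mod` v + 1
          in go pos' (v + 1) (if pos' == 1 then v else zv)
```

For example, valueAfterZero 3 9 gives 9, and valueAfterZero 3 50000000 performs the same 50 million iterations as the question's call.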

I think the implication is that this is the idiomatic way to write (some) performant code as well!
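Another idiomatic option, sketched here under the assumption that only the final value is needed, is a strict left fold with foldl' from Data.List over a state type with strict fields. foldl' only forces the accumulator to weak head normal form, so it is the strict (!) fields that keep the position and the answer evaluated; with a plain tuple accumulator, the components can still build up thunks. State and valueAfterZeroFold are hypothetical names.

```haskell
import Data.List (foldl')

-- Loop state: current position and the value right after 0.
-- The ! annotations make both fields strict, so evaluating a State to
-- weak head normal form also evaluates the two Ints inside it.
data State = State !Int !Int

-- Hypothetical name: the value right after 0 once 1..n have been inserted
-- with step size `step`. foldl' forces the accumulator at every step.
valueAfterZeroFold :: Int -> Int -> Int
valueAfterZeroFold step n = answer (foldl' insertNext (State 0 0) [1 .. n])
  where
    insertNext (State pos zv) v =
      let pos' = (pos + step) `mod` v + 1
      in State pos' (if pos' == 1 then v else zv)
    answer (State _ zv) = zv
```

The design choice here is to move the strictness into the data type, so any consumer that forces a State gets fully evaluated fields without needing bang patterns in the loop itself.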
