Automatically detect whether a Haskell function is tail recursive

Time:01-25

I'm currently writing an auto-grader for a Haskell course. For the section on "tail-recursion", I need a way to automatically and safely detect whether a given Haskell function is tail-recursive or not.

I've searched for existing tools but couldn't find anything. I assume there must be a way to do this automatically, because after all that's what the Haskell compiler does for us. The method doesn't have to be in a specific language or anything since the grader is an external entity in the project. For example, it can be a Haskell library, a command line tool, or code written in any other language (C, Java, Python, etc).

If there actually aren't any such tools, I assume I'm going to have to use something like a lexical analyzer for Haskell and write custom code that detects tail recursion myself.
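For concreteness, here is a sketch of the distinction such a grader has to make. The module and function names are hypothetical, chosen only for illustration. In a tail-recursive clause the recursive call is the whole right-hand side (or the whole of an if/case branch). In the non-tail version, the call's result is still fed to `+`.

```haskell
module TailRecExamples where

-- Not tail recursive: after the recursive call returns,
-- its result still has to be added to x.
sumList :: [Integer] -> Integer
sumList []       = 0
sumList (x : xs) = x + sumList xs

-- Tail recursive: the recursive call is the entire right-hand
-- side, so nothing is left to do after it returns.
sumListAcc :: Integer -> [Integer] -> Integer
sumListAcc acc []       = acc
sumListAcc acc (x : xs) = sumListAcc (acc + x) xs

-- Also tail recursive: the self-call sits in a branch of an
-- if-expression, and each branch is itself in tail position.
countDown :: Int -> Int
countDown n = if n <= 0 then 0 else countDown (n - 1)
```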

Answer:

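I would first point out that tail recursion is rarely a virtue in Haskell. It's fine if you want to use Haskell as a medium for teaching tail recursion, but actually writing tail-recursive functions in Haskell is usually misguided. Because evaluation is lazy, a tail-recursive loop with a lazy accumulator still builds up a chain of unevaluated thunks. Guarded recursion, where the recursive call sits under a constructor as in map, is often the better choice.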
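The following sketch illustrates why guarded recursion is often preferred. The names mapAcc and mapGuarded are made up for this example. The tail-recursive, accumulator-based map has to reach the end of its input before it can return anything. The ordinary definition, by contrast, produces elements on demand.

```haskell
module LazinessExample where

-- Tail-recursive map with an accumulator: it must walk the whole
-- input before returning anything, so it cannot handle infinite
-- lists, and it builds a reversed intermediate list.
mapAcc :: (a -> b) -> [a] -> [b]
mapAcc f = go []
  where
    go acc []       = reverse acc
    go acc (x : xs) = go (f x : acc) xs

-- Guarded recursion: the recursive call sits under the (:)
-- constructor, so each element is available as soon as it is demanded.
mapGuarded :: (a -> b) -> [a] -> [b]
mapGuarded _ []       = []
mapGuarded f (x : xs) = f x : mapGuarded f xs

-- take 3 (mapGuarded (* 2) [1 ..])  evaluates to [2,4,6]
-- take 3 (mapAcc     (* 2) [1 ..])  never produces a result
```

This is the same reason foldr-based definitions can work on infinite lists while foldl-based ones cannot.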

Presuming you still want to do this, I would highlight this part of your question:

"after all that's what the Haskell compiler does for us"

Yes, indeed. So why would any tool other than the compiler exist? The compiler already does exactly this, so use the compiler. I'm sure it won't be trivial, because you'll have to learn the compiler's types and the rest of its API. But the result will actually be correct.

I would start by looking at a function like isAlwaysTailCalled in the GHC API and see whether it does what you want. If it doesn't, you may need to consume the AST of the function definition yourself.
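The sketch below is one starting point for the "consume the AST" route. It swaps in Template Haskell for the GHC API, which is a lighter-weight technique; exampleDecs is just an illustrative name. A declaration quote is run in IO with runQ, and the resulting Dec values are printed twice, once as raw constructors and once pretty-printed. Note that this only works for code written inside the quote, not for arbitrary student source files.

```haskell
{-# LANGUAGE TemplateHaskell #-}
module Main where

import Language.Haskell.TH

-- A quoted definition; the quote yields the TH AST (a list of Dec).
exampleDecs :: Q [Dec]
exampleDecs = [d|
    g :: Int -> Int
    g 0 = 1
    g x = g (x - 1)
  |]

main :: IO ()
main = do
  -- Running a quote in IO is fine as long as it needs no reification.
  decs <- runQ exampleDecs
  -- Raw constructors, e.g. SigD, FunD, Clause, AppE
  mapM_ print decs
  -- The same declarations pretty-printed back as Haskell source
  putStrLn (pprint decs)
```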

Answer:

I basically agree with amalloy. However, for this auto-grader, which presumably only needs to be a quick way to weed out clear-cut mistakes rather than a complete, reliable certification tool, I would just cobble something together in Template Haskell:

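```haskell
{-# LANGUAGE TemplateHaskell #-}

module TailRecCheck where

import Language.Haskell.TH

-- A rough, purely syntactic check. A definition passes if every clause
-- body is either "constant" (does not mention the function at all) or
-- a direct self-call whose arguments do not mention it. Anything the
-- checker does not recognise counts as NOT tail recursive.
isTailRecursive :: Dec -> Bool
isTailRecursive (FunD fName clauses) = all isClauseTR clauses
 where
       isClauseTR (Clause _ body _) = isBodyTR body

       -- Every guarded branch must be in tail form
       -- (the guard conditions themselves are not inspected)
       isBodyTR (NormalB e)   = isTailExpr e
       isBodyTR (GuardedB gs) = all (isTailExpr . snd) gs

       -- The branches of if/case expressions are still in tail position
       isTailExpr (CondE c t e) = isConstant c && isTailExpr t && isTailExpr e
       isTailExpr (CaseE s ms)  =
          isConstant s && all (\(Match _ b _) -> isBodyTR b) ms
       isTailExpr e = case collectApp e [] of
          -- It's only a tail call if the function is the one we're
          -- currently defining, and none of its arguments recurse
          (VarE fn, args@(_ : _)) | fn == fName -> all isConstant args
          -- Constant expressions are allowed as base cases
          _ -> isConstant e

       -- Split an application `f a b c` into (f, [a, b, c]),
       -- so that multi-argument self-calls are recognised too
       collectApp (AppE f x) args = collectApp f (x : args)
       collectApp f          args = (f, args)

       isConstant (VarE n)    = n /= fName
       isConstant (ConE _)    = True
       isConstant (LitE _)    = True
       isConstant (AppE f x)  = isConstant f && isConstant x
       isConstant (InfixE l op r)
          = all isConstant l && isConstant op && all isConstant r
       isConstant (ParensE e) = isConstant e
       isConstant _           = False   -- be conservative about the rest
-- Type signatures and other non-function declarations are not checked
isTailRecursive _ = True

assertTailRecursiveDefs :: Q [Dec] -> Q [Dec]
assertTailRecursiveDefs decsQ = decsQ >>= mapM check
 where
   check dec
     | isTailRecursive dec = return dec
     | otherwise = fail ("Function " ++ showName dec
                           ++ " is not tail recursive.")
   showName (FunD n _) = show n
   showName dec        = pprint dec
```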

To be used like this:

```haskell
{-# LANGUAGE TemplateHaskell #-}

module TailRecCheckExample where

import TailRecCheck

assertTailRecursiveDefs [d|
    f x = 4

    g 0 = 1
    g x = g (x-1)

    h 0 = 1
    h x = 1 + h (x-1)
  |]
```
Compiling this module fails with:

```
TailRecCheckExample.hs:7:1: error:
    Function h_6989586621679042356 is not tail recursive.
  |
7 | assertTailRecursiveDefs [d|
  | ^^^^^^^^^^^^^^^^^^^^^^^^^^^...
```