Is there a map accumulate for things that aren't foldable?

Time:12-29

Let's say I have a binary tree:

data BinTree a
  = Nil
  | Branch a (BinTree a) (BinTree a)

I'd like to do an accumulating map on a structure like this:

mapAccum :: (a -> b -> (c, a)) -> a -> BinTree b -> BinTree c
mapAccum func x Nil =
  Nil
mapAccum func x (Branch y left right) =
  let
    (y', x') =
      func x y
  in
    Branch y' (mapAccum func x' left) (mapAccum func x' right) 

This performs a map in which the accumulator is threaded from each node down to its children.

However, this is a very general pattern. We can do this on all sorts of tree-like structures, and if there's a nice, commonplace abstraction for it, I'd prefer to use that over rolling my own.

There's a function on Traversables:

mapAccumL :: Traversable t => (s -> a -> (s, b)) -> s -> t a -> (s, t b)

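It sort of does the same thing, but it treats the structure as a sequence: the accumulator is threaded through the elements one after another, as in a list. On a binary tree I want each subtree to receive the accumulator produced at its parent instead, so mapAccumL doesn't fit. What I'm looking for is a more basic version of this that works without Foldable.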
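For comparison, here is a small sketch that assumes the DeriveTraversable extension, which lets GHC derive Traversable for this BinTree. With that instance, mapAccumL does typecheck on the tree. However, it threads the accumulator through the nodes in traversal order (node, then left subtree, then right subtree). The mapAccum above instead hands each child the accumulator produced at its parent. The names sample, viaTraversable and viaBranches are illustrative only.

```haskell
{-# LANGUAGE DeriveTraversable #-}

import Data.Foldable (toList)
import Data.Traversable (mapAccumL)

data BinTree a
  = Nil
  | Branch a (BinTree a) (BinTree a)
  deriving (Show, Functor, Foldable, Traversable)

-- The question's function: both children receive the accumulator
-- produced at their parent, so siblings see the same value.
mapAccum :: (a -> b -> (c, a)) -> a -> BinTree b -> BinTree c
mapAccum _ _ Nil = Nil
mapAccum func x (Branch y left right) =
  let (y', x') = func x y
  in Branch y' (mapAccum func x' left) (mapAccum func x' right)

-- 'a' at the root, with children 'b' and 'c'.
sample :: BinTree Char
sample = Branch 'a' (Branch 'b' Nil Nil) (Branch 'c' Nil Nil)

-- mapAccumL: the counter runs through every node in traversal order.
viaTraversable :: BinTree Char -> BinTree (Char, Int)
viaTraversable = snd . mapAccumL (\n y -> (n + 1, (y, n))) 0

-- mapAccum: the counter runs down each branch, so it records depth.
viaBranches :: BinTree Char -> BinTree (Char, Int)
viaBranches = mapAccum (\n y -> ((y, n), n + 1)) 0

main :: IO ()
main = do
  print (toList (viaTraversable sample)) -- [('a',0),('b',1),('c',2)]
  print (toList (viaBranches sample))    -- [('a',0),('b',1),('c',1)]
```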

I can make it work on types made with Cofree:

import Control.Comonad.Cofree (Cofree (..))

mapAccum ::
  ( Functor f
  )
    => (a -> b -> (c, a)) -> a -> Cofree f b -> Cofree f c
mapAccum func x (y :< rest) =
  let
    (y', x') =
      func x y
  in
    y' :< fmap (mapAccum func x') rest

This shows that it's at least generally applicable to tree-like structures.
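Here is a runnable sketch of the Cofree version on a rose tree (Cofree []). So that it depends only on base, Cofree is redefined locally as a stand-in for Control.Comonad.Cofree from the free package. The helpers rose, flatten and depths are hypothetical and exist only for the demo.

```haskell
-- Local stand-in for Control.Comonad.Cofree (free package),
-- defined here only so the sketch runs with base alone.
data Cofree f a = a :< f (Cofree f a)
infixr 5 :<

-- The question's Cofree version, unchanged apart from layout.
mapAccum :: Functor f => (a -> b -> (c, a)) -> a -> Cofree f b -> Cofree f c
mapAccum func x (y :< rest) =
  let (y', x') = func x y
  in y' :< fmap (mapAccum func x') rest

-- Preorder list of the labels, for inspection.
flatten :: Cofree [] a -> [a]
flatten (y :< kids) = y : concatMap flatten kids

-- "root" has children "x" (which has child "z") and "y".
rose :: Cofree [] String
rose = "root" :< ["x" :< ["z" :< []], "y" :< []]

-- Pair every label with its depth.
depths :: Cofree [] String -> Cofree [] (String, Int)
depths = mapAccum (\d s -> ((s, d), d + 1)) 0

main :: IO ()
main = print (flatten (depths rose))
-- [("root",0),("x",1),("z",2),("y",1)]
```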

Is there a common abstraction for this function?

Answer:

{-# LANGUAGE TemplateHaskell, TypeFamilies, DeriveTraversable, ScopedTypeVariables #-}

import Data.Functor.Foldable.TH
import Data.Functor.Foldable

data BinTree a
  = Nil
  | Branch a (BinTree a) (BinTree a)
  deriving (Functor)

makeBaseFunctor ''BinTree

The Template Haskell splice will define

data BinTreeF a x
  = NilF
  | BranchF a x x
  deriving (Functor)

as well as instances of the Base type family and the Recursive and Corecursive classes for BinTree. Thanks to that Recursive instance, you can use

cata :: Recursive t => (Base t r -> r) -> t -> r

at the type (with the result r instantiated to a function, a -> BinTree c)

cata :: (BinTreeF b (a -> BinTree c) -> a -> BinTree c) -> BinTree b -> a -> BinTree c

Specifically,

mapAccum :: (a -> b -> (c, a)) -> a -> BinTree b -> BinTree c
mapAccum func = \x t -> cata go t x
  where
    go NilF _x = Nil
    go (BranchF y leftres rightres) x =
      let
        (y', x') = func x y
      in
        Branch y' (leftres x') (rightres x') 
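The trick in the cata version is that the fold's result type is a function, a -> BinTree c. Each folded subtree is therefore still waiting for the accumulator that its parent will compute. The sketch below spells this out without Template Haskell. BinTreeF, projectTree and cataTree are hand-written stand-ins, assumed to behave like what makeBaseFunctor and recursion-schemes provide. pathSums is a hypothetical example.

```haskell
{-# LANGUAGE DeriveFunctor #-}

data BinTree a = Nil | Branch a (BinTree a) (BinTree a)
  deriving (Show, Eq)

-- Hand-written equivalent of the base functor makeBaseFunctor generates.
data BinTreeF a x = NilF | BranchF a x x
  deriving (Functor)

-- Peel off one layer (plays the role of recursion-schemes' project).
projectTree :: BinTree a -> BinTreeF a (BinTree a)
projectTree Nil = NilF
projectTree (Branch y l r) = BranchF y l r

-- cata specialized to BinTree: fold the children, then apply the algebra.
cataTree :: (BinTreeF a r -> r) -> BinTree a -> r
cataTree alg = alg . fmap (cataTree alg) . projectTree

mapAccum :: (a -> b -> (c, a)) -> a -> BinTree b -> BinTree c
mapAccum func = \x t -> cataTree go t x
  where
    -- Here r = a -> BinTree c: leftres and rightres await an accumulator.
    go NilF _x = Nil
    go (BranchF y leftres rightres) x =
      let (y', x') = func x y
      in Branch y' (leftres x') (rightres x')

-- Running sum along each root-to-node path.
pathSums :: BinTree Int -> BinTree Int
pathSums = mapAccum (\acc y -> (acc + y, acc + y)) 0

main :: IO ()
main = print (pathSums (Branch 1 (Branch 2 Nil Nil) (Branch 3 Nil Nil)))
-- Branch 1 (Branch 3 Nil Nil) (Branch 4 Nil Nil)
```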

Alternatively, you could use

transverse :: (Recursive s, Corecursive t, Functor f) => (forall r. Base s (f r) -> f (Base t r)) -> s -> f t

at the type

transverse :: (forall r. BinTreeF b (a -> r) -> a -> BinTreeF c r) -> BinTree b -> a -> BinTree c

like so

-- the explicit forall (with ScopedTypeVariables) lets go's signature refer to a, b and c
mapAccum :: forall a b c. (a -> b -> (c, a)) -> a -> BinTree b -> BinTree c
mapAccum func = \x t -> transverse go t x
  where
    go :: BinTreeF b (a -> r) -> a -> BinTreeF c r
    go NilF _x = NilF
    go (BranchF y leftres rightres) x =
      let
        (y', x') = func x y
      in
        BranchF y' (leftres x') (rightres x') 
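To see what transverse does at that type, here is a hand-rolled version specialized to BinTree. transverseTree, projectTree and embedTree are local stand-ins, not the recursion-schemes definitions. The sketch processes the children first, then runs go on the current layer inside the functor f, then rebuilds the node. With f chosen as the reader functor ((->) a), go receives children that are still functions of the accumulator, and it never has to mention BinTree itself.

```haskell
{-# LANGUAGE DeriveFunctor, RankNTypes, ScopedTypeVariables #-}

data BinTree a = Nil | Branch a (BinTree a) (BinTree a)
  deriving (Show, Eq)

data BinTreeF a x = NilF | BranchF a x x
  deriving (Functor)

projectTree :: BinTree a -> BinTreeF a (BinTree a)
projectTree Nil = NilF
projectTree (Branch y l r) = BranchF y l r

embedTree :: BinTreeF a (BinTree a) -> BinTree a
embedTree NilF = Nil
embedTree (BranchF y l r) = Branch y l r

-- Transform the children, run n on this layer inside f, then rebuild.
transverseTree
  :: Functor f
  => (forall r. BinTreeF b (f r) -> f (BinTreeF c r))
  -> BinTree b -> f (BinTree c)
transverseTree n = fmap embedTree . n . fmap (transverseTree n) . projectTree

mapAccum :: forall a b c. (a -> b -> (c, a)) -> a -> BinTree b -> BinTree c
mapAccum func = \x t -> transverseTree go t x
  where
    -- f is ((->) a): each child is a function still awaiting the accumulator.
    go :: BinTreeF b (a -> r) -> a -> BinTreeF c r
    go NilF _x = NilF
    go (BranchF y leftres rightres) x =
      let (y', x') = func x y
      in BranchF y' (leftres x') (rightres x')

main :: IO ()
main = print (mapAccum (\d y -> ((y, d), d + 1)) (0 :: Int) (Branch 'a' (Branch 'b' Nil Nil) Nil))
-- Branch ('a',0) (Branch ('b',1) Nil Nil) Nil
```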

Answer:

Here's a generalization of what you wrote, using the bifunctors and recursion-schemes packages:

{-# LANGUAGE TemplateHaskell, TypeFamilies, DeriveTraversable, FlexibleContexts #-}

import Control.Monad.Trans.State.Lazy
import Data.Bifunctor.TH
import Data.Bitraversable
import Data.Functor.Foldable
import Data.Functor.Foldable.TH

data BinTree a = Nil | Branch a (BinTree a) (BinTree a)
makeBaseFunctor ''BinTree
deriveBifunctor ''BinTreeF
deriveBifoldable ''BinTreeF
deriveBitraversable ''BinTreeF

mapAccum :: (Base tc ~ f c, Base tb ~ f b, Bitraversable f, Recursive tb,
  Corecursive tc) => (a -> b -> (c, a)) -> a -> tb -> tc
mapAccum func x ys = embed ys' where
  (ys', x') = runState (bitraverse (state . flip func) (pure . mapAccum func x') (project ys)) x

-- a slightly less general version, but that's usually good enough,
-- and will fix most ambiguous type errors
mapAccum' :: (Base (t c) ~ f c, Base (t b) ~ f b, Bitraversable f, Recursive (t b),
  Corecursive (t c)) => (a -> b -> (c, a)) -> a -> t b -> t c
mapAccum' = mapAccum

It works by traversing all of the values stored in the current layer of the tree (for your tree this is at most one value: none for Nil, one for Branch). It transforms them and comes up with a new accumulator value, then recursively calls itself on each subtree with that value.

Also, since it's lazy state, it ties the knot, so it only has to walk the structure once instead of twice. Note that x' comes from the output of runState, yet it's also passed in as part of the computation given to runState. In a strict language this would result in an infinite loop, but since Haskell is lazy, x' isn't evaluated until it's needed, and by then the part of the code that produces it has already finished.
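The knot-tying trick can be seen in isolation by specializing the idea to BinTree without recursion-schemes. This is a sketch, not the answer's code. mapAccumTree is a hypothetical name, and the sketch uses runState and state from Control.Monad.Trans.State.Lazy (transformers), as the answer does.

```haskell
import Control.Monad.Trans.State.Lazy (runState, state)

data BinTree a = Nil | Branch a (BinTree a) (BinTree a)
  deriving (Show, Eq)

mapAccumTree :: (a -> b -> (c, a)) -> a -> BinTree b -> BinTree c
mapAccumTree _ _ Nil = Nil
mapAccumTree func x (Branch y left right) = t
  where
    -- x' is the final state of this very runState call. Mentioning it
    -- inside step is fine because the children are built as thunks;
    -- x' is only forced when a child is inspected, and by then
    -- `state (flip func y)` has already produced it.
    (t, x') = runState step x
    step = do
      y' <- state (flip func y) -- use the current accumulator, store the new one
      pure (Branch y' (mapAccumTree func x' left) (mapAccumTree func x' right))

main :: IO ()
main = print (mapAccumTree (\d y -> ((y, d), d + 1)) (0 :: Int)
                (Branch 'a' (Branch 'b' Nil Nil) (Branch 'c' Nil Nil)))
-- Branch ('a',0) (Branch ('b',1) Nil Nil) (Branch ('c',1) Nil Nil)
```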
