Let's say I have a binary tree:
data BinTree a
  = Nil
  | Branch a (BinTree a) (BinTree a)
I'd like to do an accumulate map on a structure like this:
mapAccum ::
  (
  )
  => (a -> b -> (c, a)) -> a -> BinTree b -> BinTree c
mapAccum func x Nil =
  Nil
mapAccum func x (Branch y left right) =
  let
    (y', x') =
      func x y
  in
    Branch y' (mapAccum func x' left) (mapAccum func x' right)
This performs a map with an accumulator across the structure, handing each node's updated accumulator down to both of its subtrees.
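To make the behaviour concrete, here is a minimal usage sketch of the function above. The helper withDepth and the sample tree example are hypothetical names introduced only for illustration; the accumulator is used as a depth counter.

```haskell
data BinTree a
  = Nil
  | Branch a (BinTree a) (BinTree a)
  deriving (Show, Eq)

-- Same definition as in the question, with the empty constraint dropped.
mapAccum :: (a -> b -> (c, a)) -> a -> BinTree b -> BinTree c
mapAccum _ _ Nil = Nil
mapAccum func x (Branch y left right) =
  let (y', x') = func x y
  in Branch y' (mapAccum func x' left) (mapAccum func x' right)

-- Hypothetical helper: pair every element with its depth.
-- The accumulator is the current depth; each node passes depth + 1
-- to both of its children.
withDepth :: BinTree b -> BinTree (Int, b)
withDepth = mapAccum (\depth y -> ((depth, y), depth + 1)) 0

-- Hypothetical sample tree.
example :: BinTree Char
example =
  Branch 'a'
    (Branch 'b' (Branch 'd' Nil Nil) Nil)
    (Branch 'c' Nil Nil)
```

In this sketch the siblings 'b' and 'c' both end up paired with depth 1, because each receives the accumulator produced at their parent.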
However, this is a very general pattern. We can do this on all sorts of tree-like structures, and if there's a nice, commonplace abstraction I'd prefer to use that over rolling my own here.
There's a function on Traversables:
mapAccumL :: Traversable t => (s -> a -> (s, b)) -> s -> t a -> (s, t b)
Which sort of does the same thing on lists. But it requires Foldable in a way that means it wouldn't work on binary trees. What I'm looking for would be a more basic version of this which works without Foldable.
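For comparison, here is a sketch of what mapAccumL does when BinTree is given a derived Traversable instance. That instance is an assumption (the question does not derive one), and example and numbered are hypothetical names.

```haskell
{-# LANGUAGE DeriveTraversable #-}

import Data.Traversable (mapAccumL)

data BinTree a
  = Nil
  | Branch a (BinTree a) (BinTree a)
  deriving (Show, Eq, Functor, Foldable, Traversable)

-- Hypothetical sample tree.
example :: BinTree Char
example = Branch 'a' (Branch 'b' Nil Nil) (Branch 'c' Nil Nil)

-- Number the elements with mapAccumL. The counter is threaded through
-- the elements one after another in traversal order, so it flows from
-- the left subtree into the right subtree.
numbered :: (Int, BinTree (Int, Char))
numbered = mapAccumL (\n c -> (n + 1, (n, c))) 0 example
```

Here the siblings 'b' and 'c' are numbered 1 and 2, because the state leaving the left subtree enters the right one. The mapAccum in the question would instead hand the same accumulator to both siblings, which is the difference between the two functions.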
I can make it work on types made with Cofree:
mapAccum ::
  ( Functor f
  )
  => (a -> b -> (c, a)) -> a -> Cofree f b -> Cofree f c
mapAccum func x (y :< rest) =
  let
    (y', x') =
      func x y
  in
    y' :< fmap (mapAccum func x') rest
Which shows that it's at least generally applicable to tree-like structures.
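As a small illustration of the Cofree version, the sketch below defines Cofree by hand so that it needs no extra package (the free package provides the real type), and runs the function on a rose tree, Cofree []. The names Rose, pathSums, preorder and example are hypothetical.

```haskell
infixr 5 :<

-- Hand-written stand-in for Cofree from the free package.
data Cofree f a = a :< f (Cofree f a)

-- Same definition as above.
mapAccum :: Functor f => (a -> b -> (c, a)) -> a -> Cofree f b -> Cofree f c
mapAccum func x (y :< rest) =
  let (y', x') = func x y
  in y' :< fmap (mapAccum func x') rest

-- A rose tree: every node carries a list of children.
type Rose = Cofree []

-- Replace every label by the sum of the labels on the path from the root.
pathSums :: Rose Int -> Rose Int
pathSums = mapAccum (\acc y -> let s = acc + y in (s, s)) 0

-- List the labels in pre-order, for easy inspection.
preorder :: Rose a -> [a]
preorder (a :< kids) = a : concatMap preorder kids

-- Hypothetical sample tree.
example :: Rose Int
example = 1 :< [2 :< [4 :< []], 3 :< []]
```

With this sample tree, preorder (pathSums example) gives [1, 3, 7, 4]: each child starts from the sum accumulated at its own parent, independently of its siblings.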
Is there a common abstraction for this function?
CodePudding user response:
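Using the recursion-schemes package:
{-# LANGUAGE DeriveFunctor, DeriveFoldable, DeriveTraversable, TemplateHaskell, TypeFamilies #-}
import Data.Functor.Foldable.TH
import Data.Functor.Foldable
data BinTree a
  = Nil
  | Branch a (BinTree a) (BinTree a)
  deriving (Functor)
makeBaseFunctor ''BinTree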
The Template Haskell splice will define
data BinTreeF a x
  = NilF
  | BranchF a x x
  deriving (Functor)
as well as instances of the Base type family and the Recursive and Corecursive classes for BinTree. Thanks to that Recursive instance, you can use
cata :: Recursive t => (Base t r -> r) -> t -> r
at the type
cata :: (BinTreeF b (a -> BinTree c) -> a -> BinTree c) -> BinTree b -> a -> BinTree c
Specifically,
mapAccum :: (a -> b -> (c, a)) -> a -> BinTree b -> BinTree c
mapAccum func = \x t -> cata go t x
  where
    -- Each subtree has already been folded into a function that is
    -- still waiting for its accumulator.
    go NilF _x = Nil
    go (BranchF y leftres rightres) x =
      let
        (y', x') = func x y
      in
        Branch y' (leftres x') (rightres x')
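To see why the fold works without Template Haskell getting in the way, here is a self-contained sketch in which BinTreeF, project and cata are written by hand as stand-ins for what recursion-schemes provides, specialised to BinTree. The sample tree example is hypothetical.

```haskell
{-# LANGUAGE DeriveFunctor #-}

data BinTree a
  = Nil
  | Branch a (BinTree a) (BinTree a)
  deriving (Show, Eq)

-- Hand-written stand-in for the base functor generated by makeBaseFunctor.
data BinTreeF a x
  = NilF
  | BranchF a x x
  deriving (Functor)

-- Hand-written stand-in for project: peel off one layer of the tree.
project :: BinTree a -> BinTreeF a (BinTree a)
project Nil = NilF
project (Branch a l r) = BranchF a l r

-- Hand-written stand-in for cata, specialised to BinTree.
cata :: (BinTreeF a r -> r) -> BinTree a -> r
cata alg = alg . fmap (cata alg) . project

-- The result type r of the fold is a function (a -> BinTree c), so the
-- fold builds up a function that is then applied to the initial accumulator.
mapAccum :: (a -> b -> (c, a)) -> a -> BinTree b -> BinTree c
mapAccum func = \x t -> cata go t x
  where
    go NilF _ = Nil
    go (BranchF y leftres rightres) x =
      let (y', x') = func x y
      in Branch y' (leftres x') (rightres x')

-- Hypothetical sample tree.
example :: BinTree Char
example = Branch 'a' (Branch 'b' Nil Nil) (Branch 'c' Nil Nil)
```

The key design choice is picking a function type for the carrier of the fold: information flows bottom-up through cata, while the accumulator flows top-down through the functions it returns.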
Alternatively, you could use
transverse :: (Recursive s, Corecursive t, Functor f) => (forall r. Base s (f r) -> f (Base t r)) -> s -> f t
at the type
transverse :: (forall r. BinTreeF b (a -> r) -> a -> BinTreeF c r) -> BinTree b -> a -> BinTree c
like so
-- needs ScopedTypeVariables, so that a, b and c in go's signature
-- refer to the type variables bound by the outer forall
mapAccum :: forall a b c. (a -> b -> (c, a)) -> a -> BinTree b -> BinTree c
mapAccum func = \x t -> transverse go t x
  where
    go :: BinTreeF b (a -> r) -> a -> BinTreeF c r
    go NilF _x = NilF
    go (BranchF y leftres rightres) x =
      let
        (y', x') = func x y
      in
        BranchF y' (leftres x') (rightres x')
CodePudding user response:
Here's a generalization of what you wrote, using the bifunctors and recursion-schemes packages:
{-# LANGUAGE TemplateHaskell, TypeFamilies #-}
import Control.Monad.Trans.State.Lazy
import Data.Bifunctor.TH
import Data.Bitraversable
import Data.Functor.Foldable
import Data.Functor.Foldable.TH
data BinTree a = Nil | Branch a (BinTree a) (BinTree a)
makeBaseFunctor ''BinTree
deriveBifunctor ''BinTreeF
deriveBifoldable ''BinTreeF
deriveBitraversable ''BinTreeF
mapAccum :: (Base tc ~ f c, Base tb ~ f b, Bitraversable f, Recursive tb,
             Corecursive tc) => (a -> b -> (c, a)) -> a -> tb -> tc
mapAccum func x ys = embed ys' where
  (ys', x') = runState (bitraverse (state . flip func) (pure . mapAccum func x') (project ys)) x
-- a slightly less general version, but that's usually good enough,
-- and will fix most ambiguous type errors
mapAccum' :: (Base (t c) ~ f c, Base (t b) ~ f b, Bitraversable f, Recursive (t b),
              Corecursive (t c)) => (a -> b -> (c, a)) -> a -> t b -> t c
mapAccum' = mapAccum
The way it works is that it traverses over all of the values in the current layer of the tree (for your tree in particular, that is at most one value: none for Nil, one for Branch), transforming them and coming up with a new accumulator value, then recursively calls itself on each subtree with that value. Also, since it uses lazy state, it ties the knot so it only has to walk each layer once instead of twice. In other words, note that x' comes from the output of runState, but it's also passed in as part of the arguments to runState. In a strict language, this would result in an infinite loop, but since Haskell is lazy, x' isn't evaluated until it's needed, at which point the part of the code that generates it has finished.
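The knot-tying can be seen in isolation in a sketch like the following. It assumes a plain BinTree and replaces the State monad with an explicit pair, so it is a simplified illustration rather than the generic code above; mapAccumKnot, visit and example are hypothetical names.

```haskell
data BinTree a
  = Nil
  | Branch a (BinTree a) (BinTree a)
  deriving (Show, Eq)

mapAccumKnot :: (a -> b -> (c, a)) -> a -> BinTree b -> BinTree c
mapAccumKnot _ _ Nil = Nil
mapAccumKnot func x (Branch y l r) = node
  where
    -- x' is produced by visit, yet visit also mentions x' in its body.
    -- The pattern binding is lazy, so this is fine.
    (node, x') = visit x

    -- One pass over the current layer: transform the element, return the
    -- new accumulator, and set up the recursive calls with x'.
    visit s0 =
      let (y', s1) = func s0 y
      in (Branch y' (mapAccumKnot func x' l) (mapAccumKnot func x' r), s1)

-- Hypothetical sample tree.
example :: BinTree Char
example = Branch 'a' (Branch 'b' Nil Nil) (Branch 'c' Nil Nil)
```

The recursive calls that use x' are stored as unevaluated thunks inside the Branch constructor. By the time anything inspects a subtree, the pair returned by visit already exists, so x' simply resolves to the second component of func x y.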