I have a program which traverses an expression tree that does algebra on probability distributions, either sampling or computing the resulting distribution.
I have two implementations computing the distribution: one (computeDistribution) nicely reusable with monad transformers and one (simpleDistribution) where I concretize everything by hand. I would like to not concretize everything by hand, since that would be code duplication between the sampling and computing code.
I also have two data representations:
type Measure a = [(a, Rational)]
-- data Distribution a = Distribution (Measure a) deriving Show
newtype Distribution a = Distribution (Measure a) deriving Show
When I use the data version with the reusable code, computing the distribution of 20d2 (ghc -O3 program.hs; time ./program 20 > /dev/null) takes about one second, which seems way too long. Pick higher values of n at your own peril.
When I use the hand-concretized code, or I use the newtype representation with either implementation, computing 20d2 (time ./program 20 s > /dev/null) takes the blink of an eye.
Why?
How can I find out why?
My knowledge of how Haskell is executed is almost nil. I gather there's a graph of thunks in basically the same shape as the program, but that's about all I know.
I figure with newtype the representation of Distribution is the same as that of Measure, i.e. it's just a list, whereas with the data version each Distribution is kind of like a single-field record holding a pointer to the contained list, and so the data version has to perform more allocations. Is this true? If true, is this enough to explain the performance difference?
I'm new to working with monad transformer stacks. Consider the Let and Uniform cases in simpleDistribution: do they do the same as the walkTree-based implementation? How do I tell?
Here's my program. Note that Uniform n corresponds to rolling an n-sided die (in case the unary-ness was surprising).
Update: based on comments I simplified my program by removing everything not contributing to the performance gap. I made two semantic changes: probabilities are now denormalized and all wonky and wrong, and the simplification step is gone. But the essential shape of my program is still there. (See question edit history for the non-simplified program.)
Update 2: I made further simplifications, reducing Distribution down to the list monad with a small twist, removing everything to do with probabilities, and shortening the names. I still observe large performance differences when using data but not newtype.
import Control.Monad (liftM2)
import Control.Monad.Trans (lift)
import Control.Monad.Reader (ReaderT, runReaderT)
import System.Environment (getArgs)
import Text.Read (readMaybe)

main = do
  args <- getArgs
  let dieCount = case map readMaybe args of Just n : _ -> n; _ -> 10
  let f = if ["s"] == (take 1 $ drop 1 $ args) then fast else slow
  print $ f dieCount

fast, slow :: Int -> P Integer
fast n = walkTree n
slow n = walkTree n `runReaderT` ()

walkTree 0 = uniform
walkTree n = liftM2 (+) (walkTree 0) (walkTree $ n - 1)

data P a = P [a] deriving Show
-- newtype P a = P [a] deriving Show

class Monad m => MonadP m where uniform :: m Integer
instance MonadP P where uniform = P [1, 1]
instance MonadP p => MonadP (ReaderT env p) where uniform = lift uniform

instance Functor P where fmap f (P pxs) = P $ fmap f pxs

instance Applicative P where
  pure x = P [x]
  (P pfs) <*> (P pxs) = P $ pfs <*> pxs

instance Monad P where
  (P pxs) >>= f = P $ do
    x <- pxs
    case f x of P fxs -> fxs
CodePudding user response:
How can I find out why?
This is, in general, hard.
The extreme way to do it is to look at the core code (which you can produce by running GHC with -ddump-simpl). This can get complicated really quickly, and it's basically a whole new language to learn. Your program is already big enough that I had trouble learning much from the core dump.
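As a minimal sketch of how to request that dump, the flags can go in a per-file pragma instead of on the command line. The tiny main below is only a stand-in for the real program. -ddump-to-file writes the output next to the source file rather than to the terminal, -dsuppress-all hides most of the annotations that make Core hard to read, and -ddump-rule-firings lists which rewrite rules actually fired.

```haskell
{-# OPTIONS_GHC -O2 -ddump-simpl -ddump-to-file -dsuppress-all -ddump-rule-firings #-}
module Main (main) where

-- Stand-in program (an assumption for illustration); in practice the pragma
-- would sit at the top of program.hs from the question.
main :: IO ()
main = print (sum (map (* 2) [1 .. 10 :: Int]))
```

Comparing the rule-firings output of the data build against the newtype build is probably a cheaper first step than reading the full Core.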
The other way to find out why is to just keep using GHC and asking questions and learning about GHC optimizations until you recognize certain patterns.
Why?
In short, I believe it's due to list fusion.
NOTE: I don't know for sure that this answer is correct, and it would take more time/work to verify than I'm willing to put in right now. That said, it fits the evidence.
First off, we can check whether the slowdown you're seeing comes from something truly fundamental or from a GHC optimization that does or doesn't trigger, by compiling with -O0, that is, without optimizations. In this mode, both Distribution representations result in about the same (excruciatingly long) runtime. This leads me to believe that the data representation is not inherently the problem; rather, an optimization fires with the newtype version that doesn't with the data version.
When GHC is run in -O1 or higher, it engages certain rewrite rules to fuse different folds and maps of lists together so that it doesn't need to allocate intermediate values. (See https://markkarpov.com/tutorial/ghc-optimization-and-fusion.html#fusion for a decent tutorial on this concept as well as https://stackoverflow.com/a/38910170/14802384 which additionally has a link to a gist with all of the rewrite rules in base.) Since computeDistribution is basically just a bunch of list manipulations (which are all essentially folds), there is the potential for these to fire.
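To make "fusing" concrete, here is a small sketch with two hypothetical functions, unfused and fused, that compute the same value. The first is written as a pipeline of list functions. The second is what that pipeline looks like after being collapsed by hand into a single fold. It shows the idea only and is not GHC's literal output.

```haskell
-- Pipeline form: conceptually, filter and map each produce an intermediate list.
unfused :: [Integer] -> Integer
unfused xs = sum (map (* 2) (filter even xs))

-- Hand-fused form: one foldr pass over the input and no intermediate lists.
-- This is roughly the shape that list fusion aims for.
fused :: [Integer] -> Integer
fused = foldr step 0
  where
    step x acc
      | even x    = 2 * x + acc
      | otherwise = acc

main :: IO ()
main = print (unfused [1 .. 100], fused [1 .. 100])
```

Both components of the printed pair are equal; the difference is only in how much intermediate structure gets allocated along the way.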
The key is that with the newtype representation of Distribution, the newtype wrapper is erased during compilation, and the list operations are allowed to fuse. However, with the data representation, the wrappers are not erased, and the rewrite rules do not fire.
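The difference between the two declarations can be seen in a small sketch. The names PN and PD are hypothetical and only mirror the two versions of the type. It shows the representational difference, not where the time in the original program goes.

```haskell
import Data.Coerce (coerce)

-- Hypothetical names for illustration, mirroring the two declarations.
newtype PN a = PN [a] deriving Show
data    PD a = PD [a] deriving Show

-- A newtype shares its runtime representation with the wrapped type,
-- so coerce converts between them without doing any work.
unwrapFree :: PN a -> [a]
unwrapFree = coerce

-- A data constructor is a real heap object, so unwrapping is a pattern match.
unwrapBox :: PD a -> [a]
unwrapBox (PD xs) = xs

-- Matching on a newtype constructor forces nothing, so this returns True
-- even for an undefined argument.
matchN :: PN a -> Bool
matchN (PN _) = True

-- Matching on a data constructor forces the argument to weak head normal
-- form, so matchD undefined would throw instead of returning.
matchD :: PD a -> Bool
matchD (PD _) = True

main :: IO ()
main = do
  print (unwrapFree (PN [1, 2, 3 :: Int]))
  print (unwrapBox (PD [1, 2, 3 :: Int]))
  print (matchN (undefined :: PN Int))
  print (matchD (PD [1 :: Int]))
```

Because the data constructor is a genuine value that sits between the list operations, the list rules in base only get to see the lists on either side of it once the simplifier has eliminated the constructor.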
Therefore, I will make an unsubstantiated claim: If you want your data representation to be as fast as the newtype one, you will need to set up rewrite rules similar to the ones for list folding but that work over the Distribution type. This may involve writing your own special fold functions and then rewriting your Functor/Applicative/Monad instances to use them.
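A rough sketch of what such a rule could look like, using the simplified P type from the question. The helper mapP and the rule name "mapP/mapP" are hypothetical, and I have not checked whether rules of this kind are enough to close the gap for the original program.

```haskell
module Main (main) where

-- Same shape as the data version of P in the question.
data P a = P [a] deriving Show

-- Hypothetical helper: a map that works through the wrapper. The phase
-- annotation keeps it from being inlined until the final simplifier phase,
-- so the rule below has a chance to match first.
mapP :: (a -> b) -> P a -> P b
mapP f (P xs) = P (map f xs)
{-# NOINLINE [0] mapP #-}

-- Hypothetical rule: two traversals through the wrapper become one.
{-# RULES
"mapP/mapP" forall f g p. mapP f (mapP g p) = mapP (f . g) p
  #-}

main :: IO ()
main = print (mapP (+ 1) (mapP (* 2) (P [1, 2, 3 :: Integer])))
```

The Functor instance would then be defined as fmap = mapP, with similar helpers and rules for the Applicative and Monad operations. Compiling with -ddump-rule-firings shows whether the rule is actually being used.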