Why do these minor changes to F# linear algebra functions make them so much more performant?

Time:10-28

I am fairly new to programming in F#, and while I love a lot about it, one thing that bothers me is how often writing code in what feels like the 'proper' F# way makes it incredibly slow. I have been working on a neural network in F#, and my code is incredibly slow compared to my implementations in other languages. One specific case is the following pair of linear algebra functions:

// Dot Product

// Slow
let rec dotProduct (vector1 : float []) (vector2 : float []) : float =

    if vector1 |> Array.length = 0 then
        0.0
    else
        vector1.[0] * vector2.[0] + (dotProduct (vector1 |> Array.tail) (vector2 |> Array.tail))

// Fast
let dotProduct (vector1 : float []) (vector2 : float [])  =
    Array.fold2 (fun state x y -> state + x * y) 0.0 vector1 vector2
// Matrix Vector Product

// Slow

let matrixVectorProduct (matrix : float [,]) (vector : float[]) : float [] =
    [|
        for i = 0 to (matrix |> Array2D.length1) - 1 do
            yield dotProduct matrix.[i, 0..] vector
    |]

// Fast
let matrixVectorProduct (matrix : float [,]) (vector : float[]) : float [] =
    
    let mutable product = Array.zeroCreate (matrix |> Array2D.length1)
    
    for i = 0 to (matrix |> Array2D.length1) - 1 do
        product.[i] <- (dotProduct matrix.[i, 0..] vector)
    
    product

I was wondering if anyone with more F# experience could explain exactly why the second version is faster in each example, in terms of how the computer is interpreting my code. The biggest pain in coding with a high-level language like F# is that it's hard to know how your code is being optimized and run, compared to programming in a low-level language.

Answer:

For the first code sample:

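Your slow dotProduct function is doing two things that hurt CPU performance, in order of impact:

  1. Re-allocating arrays on every recursive call: each Array.tail copies all but the first element into a brand-new array, so the total amount of copying grows with the square of the vector length
  2. Not using tail recursion

The second point isn't really that big of a deal from what I measured.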
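To separate the two effects, here is a minimal sketch that keeps the recursive style but recurses on an index with an accumulator instead of on a copied array. The names dotProductFrom and dotProductRec are hypothetical, chosen only for illustration. Nothing is allocated per call, and since the recursive call is in tail position the F# compiler can turn it into a loop.

```fsharp
// Hypothetical helper: recurse on an index instead of on a copied array.
// idx is the current position, acc the running sum; nothing is allocated per call.
let rec dotProductFrom (vector1 : float []) (vector2 : float []) (idx : int) (acc : float) : float =
    if idx = vector1.Length then
        acc
    else
        // Tail call: no work remains after it returns, so the compiler can emit a jump.
        dotProductFrom vector1 vector2 (idx + 1) (acc + vector1.[idx] * vector2.[idx])

// Entry point with the same signature as the question's dotProduct.
let dotProductRec (vector1 : float []) (vector2 : float []) : float =
    dotProductFrom vector1 vector2 0 0.0

printfn "%f" (dotProductRec [| 1.0; 2.0; 3.0 |] [| 4.0; 5.0; 6.0 |])
```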

For the second sample:

Your slow version is slow because the size of the array an F# array expression produces is not fixed in advance: it needs to allocate and use an enumerator to generate each next item until it's done. In your more iterative code, you've pre-allocated an array of the right size and you're just filling it in. This is always significantly faster, and when performance is your concern, mutation and loops are usually a good way to win.
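Both matrix-vector versions in the question still call matrix.[i, 0..], which copies each row into a fresh array before the dot product runs. A minimal sketch that avoids that copy as well, assuming the same float [,] layout (matrixVectorProductDirect is a hypothetical name): Array.init allocates the result once at its final length, and the inner loop reads matrix.[i, j] directly.

```fsharp
// Matrix-vector product without per-row slicing.
// Assumes the matrix has as many columns as the vector has elements.
let matrixVectorProductDirect (matrix : float [,]) (vector : float []) : float [] =
    let rows = Array2D.length1 matrix
    let cols = Array2D.length2 matrix
    if cols <> vector.Length then
        invalidArg "vector" "vector length must equal the number of matrix columns"
    // Array.init knows the final size up front and fills each slot exactly once.
    Array.init rows (fun i ->
        let mutable acc = 0.0
        for j = 0 to cols - 1 do
            acc <- acc + matrix.[i, j] * vector.[j]
        acc)

let m = array2D [ [ 1.0; 2.0 ]; [ 3.0; 4.0 ] ]
printfn "%A" (matrixVectorProductDirect m [| 5.0; 6.0 |])
```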

There's another way to speed up your dot product code though: just do a simple loop!

let dotProductLoop (vector1 : float []) (vector2 : float []) : float =
    let mutable acc = 0.0

    for idx = 0 to vector1.Length - 1 do
        acc <- acc + (vector1.[idx] * vector2.[idx])
    
    acc

You'll notice that Array.fold2 more or less does this, but it comes with some marginal overhead, mainly because it calls your folder function through a closure for every pair of elements.
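One behavioural difference is worth knowing: Array.fold2 raises an ArgumentException when the arrays have different lengths, whereas the loop above only looks at vector1.Length, so a longer vector2 is silently truncated and a shorter one fails partway through with an IndexOutOfRangeException. If you want that safety at loop speed, a minimal sketch (dotProductChecked is a hypothetical name) validates once before looping:

```fsharp
// Loop-based dot product with the same length validation Array.fold2 performs.
let dotProductChecked (vector1 : float []) (vector2 : float []) : float =
    // A single O(1) check up front, then a plain loop with no per-element calls.
    if vector1.Length <> vector2.Length then
        invalidArg "vector2" "vectors must have the same length"
    let mutable acc = 0.0
    for idx = 0 to vector1.Length - 1 do
        acc <- acc + vector1.[idx] * vector2.[idx]
    acc
```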

I threw each approach into a benchmark to see some comparative results. As you can see, the loop approach is even faster than the fold2 call, but both are so much faster than your initial implementation that it's a clear win to take either.


BenchmarkDotNet=v0.13.1, OS=Windows 10.0.19042.1288 (20H2/October2020Update)
AMD Ryzen 9 5900X, 1 CPU, 24 logical and 12 physical cores
.NET SDK=6.0.100-rc.2.21505.57
  [Host]     : .NET 6.0.0 (6.0.21.48005), X64 RyuJIT DEBUG
  DefaultJob : .NET 6.0.0 (6.0.21.48005), X64 RyuJIT


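| Method         |         Mean |       Error |      StdDev | Ratio |    Gen 0 |   Gen 1 |   Allocated |
|--------------- |-------------:|------------:|------------:|------:|---------:|--------:|------------:|
| DotProduct     | 320,534.9 ns | 1,738.55 ns | 1,626.24 ns | 1.000 | 480.4688 | 18.0664 | 8,040,000 B |
| DotProductLoop |     625.4 ns |     1.93 ns |     1.71 ns | 0.002 |        - |       - |           - |
| DotProductFold |   1,105.1 ns |    10.77 ns |    10.07 ns | 0.003 |        - |       - |           - |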
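The Allocated figure for DotProduct is the Array.tail copying at work. The vector size used in the benchmark isn't stated, so the following back-of-the-envelope model rests on assumptions: n-element inputs, and 24 + 8m bytes for a float [] of length m on 64-bit .NET. Under those assumptions, n = 1000 gives exactly the 8,040,000 B shown in the table.

```fsharp
// Rough model (assumption): bytes allocated by the Array.tail calls in the slow dotProduct.
// A float [] of length m is assumed to take 24 bytes of header plus 8 bytes per element.
let arrayBytes (m : int64) = 24L + 8L * m

// A call on k-element inputs makes two Array.tail copies of length k - 1,
// and the recursion visits k = n, n - 1, ..., 1.
let tailAllocationBytes (n : int64) =
    Seq.sumBy (fun k -> 2L * arrayBytes (k - 1L)) (seq { 1L .. n })

printfn "%d" (tailAllocationBytes 1000L) // prints 8040000
```

The dominant term is roughly 8n² bytes, so doubling the vector length roughly quadruples the allocation, while the loop and fold versions allocate nothing.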

Another thing you can do, if you're committed to writing recursive code, is to write a private helper function that does tail recursion on Span or ReadOnlySpan:

open System

// Slice(1) creates a new view over the same memory, so no elements are copied.
let rec private dotProductImpl (vector1 : Span<float>) (vector2 : Span<float>) (acc: float) =
    if vector1.Length = 0 then
        acc
    else
        dotProductImpl (vector1.Slice(1)) (vector2.Slice(1)) (acc + vector1.[0] * vector2.[0])

The function that calls this will perform just as well as the loop that I proposed.
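That calling function isn't shown, so here is a minimal sketch of one, using ReadOnlySpan since the inputs are only read. The names dotProductSpanImpl and dotProductSpan are hypothetical; the helper takes tupled parameters so that it compiles to a single method taking the spans directly, and the wrapper checks lengths before slicing.

```fsharp
open System

// Tail-recursive helper over read-only spans. Slice(1) makes a new view over
// the same memory, so no elements are copied (unlike Array.tail).
let rec dotProductSpanImpl (vector1 : ReadOnlySpan<float>, vector2 : ReadOnlySpan<float>, acc : float) : float =
    if vector1.Length = 0 then
        acc
    else
        dotProductSpanImpl (vector1.Slice(1), vector2.Slice(1), acc + vector1.[0] * vector2.[0])

// Public entry point: wraps the arrays in spans (no copying) and starts from 0.0.
let dotProductSpan (vector1 : float []) (vector2 : float []) : float =
    if vector1.Length <> vector2.Length then
        invalidArg "vector2" "vectors must have the same length"
    dotProductSpanImpl (ReadOnlySpan<float>(vector1), ReadOnlySpan<float>(vector2), 0.0)
```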
