Fast idiomatic Floyd-Warshall algorithm in Rust

Time:11-22

I am trying to implement a reasonably fast version of the Floyd-Warshall algorithm in Rust. This algorithm finds the shortest paths between all pairs of vertices in a directed weighted graph.
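The update performed by the triple loop below is the standard Floyd-Warshall recurrence. Here it is written with the code's index names: i is the intermediate vertex handled by the outer loop, and $w(j,k)$ is assumed to be the edge weight, with $+\infty$ when there is no edge and $0$ when $j = k$:

```latex
\begin{aligned}
d^{(0)}_{jk} &= w(j, k) \\
d^{(i+1)}_{jk} &= \min\bigl(d^{(i)}_{jk},\; d^{(i)}_{ji} + d^{(i)}_{ik}\bigr), \qquad i = 0, \dots, n-1
\end{aligned}
```

Here $d^{(i)}_{jk}$ is the length of the shortest path from j to k whose intermediate vertices all lie in $\{0, \dots, i-1\}$, and $d^{(n)}$ is the result. Overwriting dist in place is valid because, when there are no negative cycles ($d_{ii} \ge 0$), step i leaves row i and column i unchanged.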

The main part of the algorithm could be written like this:

use std::cmp::min;

// dist[i][j] contains the edge length between vertices [i] and [j];
// after execution it contains the shortest path length between [i] and [j]
fn floyd_warshall(dist: &mut [Vec<i32>]) {
    let n = dist.len();
    for i in 0..n {
        for j in 0..n {
            for k in 0..n {
                dist[j][k] = min(dist[j][k], dist[j][i] + dist[i][k]);
            }
        }
    }
}
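A small usage sketch of the function above. It assumes a hypothetical `INF` sentinel for missing edges, chosen as `i32::MAX / 2` so that `INF + INF` still fits in an `i32`. The graph is made up for illustration.

```rust
use std::cmp::min;

// Hypothetical "no edge" sentinel: half of i32::MAX, so INF + INF does not overflow.
const INF: i32 = i32::MAX / 2;

// Same triple loop as above: i is the intermediate vertex.
fn floyd_warshall(dist: &mut [Vec<i32>]) {
    let n = dist.len();
    for i in 0..n {
        for j in 0..n {
            for k in 0..n {
                dist[j][k] = min(dist[j][k], dist[j][i] + dist[i][k]);
            }
        }
    }
}

// Directed graph on 4 vertices: 0->1 (5), 1->2 (3), 0->2 (10), 2->3 (1).
fn example_graph() -> Vec<Vec<i32>> {
    vec![
        vec![0, 5, 10, INF],
        vec![INF, 0, 3, INF],
        vec![INF, INF, 0, 1],
        vec![INF, INF, INF, 0],
    ]
}
```

After the call, the 0->2 distance drops from 10 to 8 (via vertex 1), 0->3 becomes 9 and 1->3 becomes 4, while unreachable pairs such as 3->0 stay at `INF`.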

This implementation is very short and easy to understand, but it runs 1.5x slower than a similar C implementation.

As I understand it, the problem is that on each vector access Rust checks that the index is within the bounds of the vector, and this adds some overhead.
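To make the overhead concrete: indexing a `Vec` or slice with `[]` is checked at runtime and panics on an out-of-range index, while `get` performs the same comparison but returns an `Option`. A minimal sketch (the function names are made up for illustration):

```rust
// `v[idx]` compares idx against v.len() and branches to a panic path on failure;
// the optimizer can drop that check only if it can prove idx < v.len().
fn checked_index(v: &[i32], idx: usize) -> i32 {
    v[idx]
}

// `get` does the same comparison but reports failure as None instead of panicking.
fn checked_get(v: &[i32], idx: usize) -> Option<i32> {
    v.get(idx).copied()
}
```

In a triple loop, each of these comparisons runs on every iteration unless the compiler can hoist or remove it.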

I rewrote this function with the get_unchecked* functions:

fn floyd_warshall_unsafe(dist: &mut [Vec<i32>]) {
    let n = dist.len();
    for i in 0..n {
        for j in 0..n {
            for k in 0..n {
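                // SAFETY: only sound if every row has at least n elements
                // (i.e. dist is a square n x n matrix).
                unsafe {
                    *dist[j].get_unchecked_mut(k) = min(
                        *dist[j].get_unchecked(k),
                        dist[j].get_unchecked(i) + dist[i].get_unchecked(k),
                    )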
                }
            }
        }
    }
}

and it really did become 1.5x faster (full code of the test).

I didn't expect bounds check to add that much overhead :(

Is it possible to rewrite this code idiomatically, without unsafe, so that it runs as fast as the unsafe version? E.g. is it possible to "prove" to the compiler that there will be no out-of-bounds access by adding some assertions to the code?

Answer:

At first blush, one would hope this would be enough:

fn floyd_warshall(dist: &mut [Vec<i32>]) {
    let n = dist.len();
    for i in 0..n {
        assert!(i < dist.len());
        for j in 0..n {
            assert!(j < dist.len());
            assert!(i < dist[j].len());
            let v2 = dist[j][i];
            for k in 0..n {
                assert!(k < dist[i].len());
                assert!(k < dist[j].len());
                dist[j][k] = min(dist[j][k], v2 + dist[i][k]);
            }
        }
    }
}

Adding asserts is a known trick to convince the Rust optimizer that indices are indeed in bounds. However, it doesn't work here. What we need is to make it even more obvious to the compiler that these loops stay in bounds, without resorting to esoteric code.
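For comparison, here is a minimal sketch of a case where the trick usually does work: one slice indexed by the loop variable. Asserting the length up front, or equivalently re-slicing to `&v[..n]`, gives the optimizer the fact `n <= v.len()`, which typically lets it drop the per-iteration checks. Whether it actually does depends on the compiler version and optimization level, so the assembly is the final word. The function names are hypothetical.

```rust
// One check before the loop instead of one per iteration (if the optimizer cooperates).
fn sum_prefix_assert(v: &[i32], n: usize) -> i32 {
    assert!(n <= v.len());
    let mut total = 0;
    for k in 0..n {
        total += v[k]; // k < n <= v.len(), so this check is provably redundant
    }
    total
}

// Same idea via re-slicing: `&v[..n]` panics once if n > v.len(),
// and afterwards the loop bound is exactly the slice length.
fn sum_prefix_slice(v: &[i32], n: usize) -> i32 {
    let v = &v[..n];
    let mut total = 0;
    for k in 0..v.len() {
        total += v[k];
    }
    total
}
```

The Floyd-Warshall loop is harder because it indexes two different rows of a `Vec<Vec<i32>>` and writes to one of them, so one length fact is not enough.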

To accomplish that, I moved to a 2D array as suggested by David Eisenstat:

fn floyd_warshall<const N:usize>(mut dist: Box<[[i32; N]; N]>) -> Box<[[i32; N]; N]> {
    for i in 0..N {
        for j in 0..N {
            for k in 0..N {
                dist[j][k] = min(dist[j][k], dist[j][i] + dist[i][k]);
            }
        }
    }
    dist
}

This uses const generics, a relatively new Rust feature (stabilized in Rust 1.51), to specify the size of a 2D array on the heap. On its own, this change does well on my machine (100ms faster than the safe version, and ~20ms behind unsafe). If, additionally, you move the v2 calculation outside the k-loop like this:

fn floyd_warshall<const N:usize>(mut dist: Box<[[i32; N]; N]>) -> Box<[[i32; N]; N]> {
    for i in 0..N {
        for j in 0..N {
            let v2 = dist[j][i];
            for k in 0..N {
                dist[j][k] = min(dist[j][k], v2 + dist[i][k]);
            }
        }
    }
    dist
}

The improvement is substantial (from ~300ms to ~100ms on my machine). The same optimization works for floyd_warshall_unsafe, bringing it to ~100ms on average on my machine. When inspecting the assembly (with #[inline(never)] on floyd_warshall), it doesn't look like bounds checks occur in either version, and both look vectorized to some extent. That said, I am no expert at reading assembly.
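Calling the const-generic version requires a `Box<[[i32; N]; N]>` whose size is fixed at compile time. One way to build it without first placing the whole matrix on the stack is to collect rows into a `Vec<[i32; N]>` and convert the boxed slice with `TryFrom`. The helper name `to_boxed_matrix` is an assumption for this sketch, not part of the answer.

```rust
use std::cmp::min;
use std::convert::TryInto;

// Const-generic version with v2 hoisted out of the k-loop, as above.
fn floyd_warshall<const N: usize>(mut dist: Box<[[i32; N]; N]>) -> Box<[[i32; N]; N]> {
    for i in 0..N {
        for j in 0..N {
            let v2 = dist[j][i];
            for k in 0..N {
                dist[j][k] = min(dist[j][k], v2 + dist[i][k]);
            }
        }
    }
    dist
}

// Hypothetical helper: copies a Vec<Vec<i32>> into a heap-allocated N x N array.
// Panics unless the input has exactly N rows of exactly N elements.
fn to_boxed_matrix<const N: usize>(rows: &[Vec<i32>]) -> Box<[[i32; N]; N]> {
    let flat: Vec<[i32; N]> = rows
        .iter()
        .map(|row| row.as_slice().try_into().expect("row length must be N"))
        .collect();
    // Box<[[i32; N]]> -> Box<[[i32; N]; N]>; fails if flat.len() != N.
    flat.into_boxed_slice()
        .try_into()
        .expect("number of rows must be N")
}
```

Since `N` must be a compile-time constant, this approach fits fixed-size benchmarks better than graphs whose size is only known at runtime.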

Because this is such a hot loop (with up to three bounds checks per iteration), I'm not surprised performance suffers so much. Unfortunately, the way the indices are used here is complicated enough to keep the assert trick from being an easy fix. There are other known cases where an assert should improve performance but the compiler is unable to make sufficient use of the information. Here is one such example.

Here is the playground with my changes.

Answer:

After some experiments, based on ideas suggested in Andrew's answer and comments in a related issue, I found a solution that:

  • still uses the same interface (i.e. &mut [Vec<i32>] as the argument)
  • doesn't use unsafe
  • is 3-4x faster than the unsafe version
  • is quite ugly :(

The code looks like this:

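use std::cmp::min;

fn floyd_warshall_fast(dist: &mut [Vec<i32>]) {
    let n = dist.len();
    for i in 0..n {
        for j in 0..n {
            // Skipping i == j is safe when dist[i][i] >= 0 (no negative cycles):
            // the update would be min(d, d + dist[i][i]), which leaves d unchanged.
            if i == j {
                continue;
            }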
            let (dist_j, dist_i) = if j < i {
                let (lo, hi) = dist.split_at_mut(i);
                (&mut lo[j][..n], &mut hi[0][..n])
            } else {
                let (lo, hi) = dist.split_at_mut(j);
                (&mut hi[0][..n], &mut lo[i][..n])
            };
            let dist_ji = dist_j[i];
            for k in 0..n {
                dist_j[k] = min(dist_j[k], dist_ji + dist_i[k]);
            }
        }
    }
}

There are a couple of ideas here:

  • We compute dist_ji once, as it doesn't change inside the innermost loop, so the compiler doesn't need to reason about it.
  • We "prove" that dist[i] and dist[j] are actually two different vectors. This is done by the ugly split_at_mut thing and the i == j special case (I would really love to know an easier solution). After that, we can treat dist[i] and dist[j] completely separately, and, for example, the compiler can vectorize the loop because it knows the data doesn't overlap.
  • The last trick is to "prove" to the compiler that both dist[i] and dist[j] have at least n elements. This is done by [..n] when computing dist[i] and dist[j] (e.g. we use &mut lo[j][..n] instead of just &mut lo[j]). After that, the compiler understands that the inner loop never accesses out-of-bounds values, and removes the checks.
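The `split_at_mut` step can be factored into a small helper, and the inner loop can be written with iterators so that no indices are left to check. This is an untested-for-speed sketch: the helper name `rows_mut` and the function name `floyd_warshall_iter` are made up, and whether this vectorizes as well as floyd_warshall_fast would need to be confirmed by benchmarking or reading the assembly. It behaves the same on square matrices; on ragged input, `zip` silently stops at the shorter row instead of panicking like `[..n]`.

```rust
use std::cmp::min;

// Hypothetical helper: returns mutable row j and shared row i (requires i != j).
// split_at_mut lets the borrow checker accept two disjoint borrows of `dist`.
fn rows_mut(dist: &mut [Vec<i32>], j: usize, i: usize) -> (&mut [i32], &[i32]) {
    assert_ne!(i, j);
    if j < i {
        let (lo, hi) = dist.split_at_mut(i);
        (lo[j].as_mut_slice(), hi[0].as_slice())
    } else {
        let (lo, hi) = dist.split_at_mut(j);
        (hi[0].as_mut_slice(), lo[i].as_slice())
    }
}

fn floyd_warshall_iter(dist: &mut [Vec<i32>]) {
    let n = dist.len();
    for i in 0..n {
        for j in 0..n {
            // With no negative cycles dist[i][i] >= 0, so the i == j pass is a no-op.
            if i == j {
                continue;
            }
            let (dist_j, dist_i) = rows_mut(dist, j, i);
            let dist_ji = dist_j[i];
            // zip walks both rows in lockstep, so there are no per-element bounds checks.
            for (d_jk, &d_ik) in dist_j.iter_mut().zip(dist_i) {
                *d_jk = min(*d_jk, dist_ji + d_ik);
            }
        }
    }
}
```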

Interestingly, the big speedup appears only when all three optimizations are used. With any two of them, the compiler can't optimize it.
