Rust 中的快速惯用 Floyd-Warshall 算法

Bor*_*rys 17 algorithm graph-algorithm rust

我正在尝试在 Rust 中实现Floyd-Warshall算法的相当快的版本。该算法找到有向加权图中所有顶点之间的最短路径。

算法的主要部分可以写成这样:

// dist[i][j] contains edge length between vertices [i] and [j]
// after the end of the execution it contains shortest path between [i] and [j]
fn floyd_warshall(dist: &mut [Vec<i32>]) {
    let n = dist.len();
    for i in 0..n {
        for j in 0..n {
            for k in 0..n {
                dist[j][k] = min(dist[j][k], dist[j][i] + dist[i][k]);
            }
        }
    }
}
Run Code Online (Sandbox Code Playgroud)

此实现非常简短且易于理解,但其运行速度比类似的 C++ 实现慢 1.5 倍。

据我了解,问题在于,在每个向量访问上,Rust 检查索引是否在向量的范围内,这会增加一些开销。

我用get_unchecked * 函数重写了这个函数:

fn floyd_warshall_unsafe(dist: &mut [Vec<i32>]) {
    let n = dist.len();
    for i in 0..n {
        for j in 0..n {
            for k in 0..n {
                unsafe {
                    *dist[j].get_unchecked_mut(k) = min(
                        *dist[j].get_unchecked(k),
                        dist[j].get_unchecked(i) + dist[i].get_unchecked(k),
                    )
                }
            }
        }
    }
}
Run Code Online (Sandbox Code Playgroud)

它的工作速度确实提高了 1.5 倍(测试的完整代码)。

我没想到边界检查会增加那么多开销:(

是否可以在没有不安全的情况下以惯用的方式重写此代码,使其与不安全版本一样快?例如,是否可以通过在代码中添加一些断言来向编译器“证明”不会出现越界访问?

Bor*_*rys 6

经过一些实验,根据安德鲁的回答中提出的想法以及相关问题中的评论,我找到了解决方案,其中:

  • 仍然使用相同的接口(例如&mut [Vec<i32>]作为参数)
  • 不使用不安全的
  • 比不安全版本快 3-4 倍
  • 相当丑陋:(

代码如下所示:

fn floyd_warshall_fast(dist: &mut [Vec<i32>]) {
    let n = dist.len();
    for i in 0..n {
        for j in 0..n {
            if i == j {
                continue;
            }
            let (dist_j, dist_i) = if j < i {
                let (lo, hi) = dist.split_at_mut(i);
                (&mut lo[j][..n], &mut hi[0][..n])
            } else {
                let (lo, hi) = dist.split_at_mut(j);
                (&mut hi[0][..n], &mut lo[i][..n])
            };
            let dist_ji = dist_j[i];
            for k in 0..n {
                dist_j[k] = min(dist_j[k], dist_ji + dist_i[k]);
            }
        }
    }
}
Run Code Online (Sandbox Code Playgroud)

里面有几个想法:

  • 我们计算dist_ji一次,因为它在最内部的循环内不会改变,编译器不需要考虑它。
  • 我们“证明”dist[i]dist[j]实际上是两个不同的向量。这是由这个丑陋的split_at_mut东西和i == j特殊情况完成的(真的很想知道一个更简单的解决方案)。之后我们可以完全分开对待dist[i]dist[j],例如编译器可以向量化这个循环,因为它知道数据不会重叠。
  • 最后一个技巧是向编译器“证明”dist[i]和都dist[j]至少有n元素。这是通过[..n]计算dist[i]dist[j](例如我们使用&mut lo[j][..n]而不是仅仅&mut lo[j])来完成的。之后,编译器知道内部循环永远不会使用越界值,并删除检查。

有趣的是,只有当所有三种优化都被使用时,它才会带来很大的加速。如果我们只使用其中任意两个,编译器就无法对其进行优化。


And*_*uke 5

乍一看,人们希望这已经足够了:

fn floyd_warshall(dist: &mut [Vec<i32>]) {
    let n = dist.len();
    for i in 0..n {
        assert!(i < dist.len());
        for j in 0..n {
            assert!(j < dist.len());
            assert!(i < dist[j].len());
            let v2 = dist[j][i];
            for k in 0..n {
                assert!(k < dist[i].len());
                assert!(k < dist[j].len());
                dist[j][k] = min(dist[j][k], v2 + dist[i][k]);
            }
        }
    }
}
Run Code Online (Sandbox Code Playgroud)

添加断言是一个众所周知的技巧,可以让 Rust 优化器相信变量确实在范围内。然而,它在这里不起作用。我们需要做的是以某种方式让 Rust 编译器更明显地知道这些循环是在边界内的,而不需要求助于深奥的代码。

为了实现这一目标,我按照 David Eisenstat 的建议转向了二维数组:

fn floyd_warshall<const N:usize>(mut dist: Box<[[i32; N]; N]>) -> Box<[[i32; N]; N]> {
    for i in 0..N {
        for j in 0..N {
            for k in 0..N {
                dist[j][k] = min(dist[j][k], dist[j][i] + dist[i][k]);
            }
        }
    }
    dist
}
Run Code Online (Sandbox Code Playgroud)

它使用常量泛型(Rust 的一个相对较新的功能)来指定堆上给定二维数组的大小。就其本身而言,此更改在我的机器上表现良好(比 usafe 快 100 毫秒,比 unsafe 慢约 20 毫秒)。另外,如果您将 v2 计算移到 k 循环之外,如下所示:

fn floyd_warshall<const N:usize>(mut dist: Box<[[i32; N]; N]>) -> Box<[[i32; N]; N]> {
    for i in 0..N {
        for j in 0..N {
            let v2 = dist[j][i];
            for k in 0..N {
                dist[j][k] = min(dist[j][k], v2 + dist[i][k]);
            }
        }
    }
    dist
}
Run Code Online (Sandbox Code Playgroud)

改进是显着的(在我的机器上从约 300 毫秒到约 100 毫秒)。同样的优化floyd_warshall_unsafe在我的机器上平均达到约 100 毫秒。当检查程序集时(使用#[inline(never)]floyd_warshall),看起来两者都没有发生边界检查,并且两者看起来都在某种程度上矢量化。虽然,我不是阅读汇编的专家。

因为这是一个如此热的循环(最多三个边界检查),所以我对性能受到如此大的影响并不感到惊讶。不幸的是,在这种情况下索引的使用非常复杂,以至于无法通过断言技巧为您提供简单的修复。还有其他已知的情况,需要断言检查来提高性能,但编译器无法充分使用该信息。这是一个这样的例子

这是我所做的改变的游乐场