Julia:成对距离的嵌套循环真的很慢

sci*_*ctn 4 julia

我有一些代码可以加载 2000 个 2D 坐标的 csv 文件,然后一个名为的函数collision_count计算比d彼此距离更近的坐标对的数量:

using BenchmarkTools
using CSV
using LinearAlgebra

function load_csv()::Array{Float64,2}
    df = CSV.read("pos.csv", header=0)
    return Matrix(df)'
end

function collision_count(pos::Array{Float64,2}, d::Float64)::Int64
    count::Int64 = 0
    N::Int64 = size(pos, 2)
    for i in 1:N
        for j in (i+1):N
            @views dist = norm(pos[:,i] - pos[:,j])
            count += dist < d
        end
    end
    return count
end
Run Code Online (Sandbox Code Playgroud)

结果如下:

pos = load_csv()

@benchmark collision_count($pos, 2.0)
BenchmarkTools.Trial: 
  memory estimate:  366.03 MiB
  allocs estimate:  5997000
  --------------
  minimum time:     152.070 ms (18.80% GC)
  median time:      158.915 ms (20.60% GC)
  mean time:        158.751 ms (20.61% GC)
  maximum time:     181.726 ms (21.98% GC)
  --------------
  samples:          32
  evals/sample:     1
Run Code Online (Sandbox Code Playgroud)

这比这个 Python 代码慢了大约 30 倍:

import numpy as np
import scipy.spatial.distance

pos = np.loadtxt('pos.csv',delimiter=',')

def collision_count(pos, d):
    pdist = scipy.spatial.distance.pdist(pos)
    return np.count_nonzero(pdist < d)

%timeit collision_count(pos, 2)

5.41 ms ± 63 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Run Code Online (Sandbox Code Playgroud)

有什么办法可以让它更快?所有的分配是怎么回事?

Osc*_*ith 5

我可以轻松获得的最快速度如下

using Distances
using StaticArrays
using LinearAlgebra

pos = [@SVector rand(2) for _ in 1:2000]
function collision_count(pos::Vector{<:AbstractVector}, d)
    count = 0
    @inbounds for i in axes(pos,2)
        for j in (i+1):lastindex(pos,2)
            dist = sqeuclidean(pos[i], pos[j])
            count += dist < d*d
        end
    end
    return count
end
Run Code Online (Sandbox Code Playgroud)

这里有各种各样的变化,有些是风格上的,有些是结构上的。从样式开始,您可能会注意到我没有输入任何比我需要的更严格的内容。这没有性能优势,因为 Julia 足够聪明,可以为您的代码推断类型。

最大的结构变化是从使用矩阵转换为向量StaticVectors。这种变化的原因是因为点是您的标量类型,所以拥有一个元素向量更有意义,其中每个元素都是一个点。我所做的下一个更改是使用平方范数,因为sqrt操作很昂贵。结果不言自明:

@benchmark collision_count($pos, .1)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     1.182 ms (0.00% GC)
  median time:      1.214 ms (0.00% GC)
  mean time:        1.218 ms (0.00% GC)
  maximum time:     2.160 ms (0.00% GC)
  --------------
  samples:          4101
  evals/sample:     1
Run Code Online (Sandbox Code Playgroud)

请注意,有些n log(n)算法可能会更快,但这对于幼稚的实现应该非常接近最佳。