将 F# 测量单位与 System.Numerics.Vector<T> 结合使用

Fra*_*ank 5 .net f# simd vectorization units-of-measurement

我很难将 F# 测量单位与System.Numerics.Vector<'T>类型结合使用。让我们看一个玩具问题:假设我们有一个类型xs的数组float<m>[],并且由于某种原因我们想要对其所有组件进行平方,从而得到一个类型的数组float<m^2>[]。这与标量代码完美配合:

xs |> Array.map (fun x -> x * x) // float<m^2>[]
Run Code Online (Sandbox Code Playgroud)

现在假设我们想要通过使用 SIMD 在 - 大小的块中执行乘法来向量化此操作System.Numerics.Vector<float>.Count,例如如下所示:

open System.Numerics
let simdWidth = Vector<float>.Count
// fill with dummy data
let xs = Array.init (simdWidth * 10) (fun i -> float i * 1.0<m>)
// array to store the results
let rs: float<m^2> array = Array.zeroCreate (xs |> Array.length)
// number of SIMD operations required
let chunks = (xs |> Array.length) / simdWidth
// for simplicity, assume xs.Length % simdWidth = 0
for i = 0 to chunks - 1 do
    let v = Vector<_>(xs, i * simdWidth) // Vector<float<m>>, containing xs.[i .. i+simdWidth-1]
    let u = v * v                        // Vector<float<m>>; expected: Vector<float<m^2>>
    u.CopyTo(rs, i * simdWidth)          // units mismatch
Run Code Online (Sandbox Code Playgroud)

我相信我明白为什么会发生这种情况:F# 编译器如何知道,System.Numerics.Vector<'T>.op_Multiply适用什么以及适用哪些算术规则?它实际上可以是任何操作。那么它应该如何推导出正确的单位呢?

问题是:实现这项工作的最佳方法是什么?我们如何告诉编译器适用哪些规则?

尝试 1:删除所有计量单位信息xs并稍后将其添加回来:

// remove all UoM from all arrays
let xsWoM = Array.map (fun x -> x / 1.0<m>) xs
// ...
// perform computation using xsWoM etc.
// ...
// add back units again
let xs = Array.map (fun x -> x * 1.0<m>) xsWoM
Run Code Online (Sandbox Code Playgroud)

问题:执行不必要的计算和/或复制操作,违背了出于性能原因对代码进行矢量化的目的。而且,这在很大程度上违背了使用 UoM 的初衷。

尝试 2:使用内联 IL 更改以下内容的返回类型Vector<'T>.op_Multiply

// reinterpret x to be of type 'b
let inline retype (x: 'a) : 'b = (# "" x: 'b #)
let inline (.*.) (u: Vector<float<'m>>) (v: Vector<float<'m>>): Vector<float<'m^2>> = u * v |> retype
// ...
let u = v .*. v // asserts type Vector<float<m^2>>
Run Code Online (Sandbox Code Playgroud)

问题:不需要任何额外的操作,但使用已弃用的功能(内联 IL)并且不完全通用(仅相对于度量单位)。

有人对此有更好的解决方案*吗

*请注意,上面的示例实际上是一个玩具问题,用于演示一般问题。真实的程序解决的是更为复杂的初值问题,涉及多种物理量。

Fra*_*ank 1

我想出了一个解决方案,可以满足我的大部分要求(看起来)。它的灵感来自TheInnerLight 的想法(wrapping Vector<'T>),但还ScalarField为底层数组数据类型添加了一个包装器(称为 )。这样我们就可以跟踪单位,而在底层我们只处理原始数据并且可以使用不单位感知的System.Numerics.VectorAPI。

一个简化的、简单的、快速的、肮脏的实现将如下所示:

// units-aware wrapper for System.Numerics.Vector<'T>
type PackedScalars<[<Measure>] 'm> = struct
    val public Data: Vector<float>
    new (d: Vector<float>) = {Data = d}
    static member inline (*) (u: PackedScalars<'m1>, v: PackedScalars<'m2>) = u.Data * v.Data |> PackedScalars<'m1*'m2>
end

// unit-ware type, wrapping a raw array for easy stream processing
type ScalarField<[<Measure>] 'm> = struct
    val public Data: float[]
    member self.Item with inline get i                = LanguagePrimitives.FloatWithMeasure<'m> self.Data.[i]
                     and  inline set i (v: float<'m>) = self.Data.[i] <- (float v)
    member self.Packed 
           with inline get i                        = Vector<float>(self.Data, i) |> PackedScalars<'m>
           and  inline set i (v: PackedScalars<'m>) = v.Data.CopyTo(self.Data, i)
    new (d: float[]) = {Data = d}
    new (count: int) = {Data = Array.zeroCreate count}
end
Run Code Online (Sandbox Code Playgroud)

我们现在可以使用这两种数据结构以相对优雅、有效的方式解决示例问题:

let xs = Array.init (simdWidth * 10) float |> ScalarField<m>    
let mutable rs = Array.zeroCreate (xs.Data |> Array.length) |> ScalarField<m^2>
let chunks = (xs.Data |> Array.length) / simdWidth
for i = 0 to chunks - 1 do
    let j = i * simdWidth
    let v = xs.Packed(j) // PackedScalars<m>
    let u = v * v        // PackedScalars<m^2>
    rs.Packed(j) <- u
Run Code Online (Sandbox Code Playgroud)

最重要的是,为单元感知ScalarField包装器重新实现常用的数组操作可能会很有用,例如

[<CompilationRepresentation(CompilationRepresentationFlags.ModuleSuffix)>]
module ScalarField =
    let map f (sf: ScalarField<_>) =
        let mutable res = Array.zeroCreate sf.Data.Length |> ScalarField
        for i = 0 to sf.Data.Length do
           res.[i] <- f sf.[i]
        res
Run Code Online (Sandbox Code Playgroud)

ETC。

缺点: 对于底层数字类型 ( float) 来说不是通用的,因为 没有通用的替代品floatWithMeasure。为了使其通用,我们必须实现第三个包装器Scalar,它也包装底层原语:

type Scalar<'a, [<Measure>] 'm> = struct
    val public Data: 'a
    new (d: 'a) = {Data = d}
end

type PackedScalars<'a, [<Measure>] 'm 
            when 'a: (new: unit -> 'a) 
            and  'a: struct 
            and  'a :> System.ValueType> = struct
    val public Data: Vector<'a>
    new (d: Vector<'a>) = {Data = d}
    static member inline (*) (u: PackedScalars<'a, 'm1>, v: PackedScalars<'a, 'm2>) = u.Data * v.Data |> PackedScalars<'a, 'm1*'m2>
end

type ScalarField<'a, [<Measure>] 'm
            when 'a: (new: unit -> 'a) 
            and  'a: struct 
            and  'a :> System.ValueType> = struct
    val public Data: 'a[]
    member self.Item with inline get i                    = Scalar<'a, 'm>(self.Data.[i])
                     and  inline set i (v: Scalar<'a,'m>) = self.Data.[i] <- v.Data
    member self.Packed 
           with inline get i                          = Vector<'a>(self.Data, i) |> PackedScalars<_,'m>
           and  inline set i (v: PackedScalars<_,'m>) = v.Data.CopyTo(self.Data, i)
    new (d:'a[]) = {Data = d}
    new (count: int) = {Data = Array.zeroCreate count}
end
Run Code Online (Sandbox Code Playgroud)

...这意味着我们基本上不是通过使用“细化”类型(如 )来跟踪单位float<'m>,而是仅通过带有辅助类型/单位参数的包装类型来跟踪单位。

不过,我仍然希望有人能提出更好的主意。:)