具有指定分布的随机数据中的类型不稳定性

Tim*_*imD 4 julia

我想从带有噪声的线性模型 (Y = X*w + e) 生成数据,我可以在其中指定输入向量 X 和标量噪声 e 的分布。为此,我指定了以下结构

using Distributions

struct NoisyLinearDataGenerator
    x_dist::ContinuousMultivariateDistribution
    noise_dist::ContinuousUnivariateDistribution
    weights::Vector{Float64}
end
Run Code Online (Sandbox Code Playgroud)

以及从中生成 N 个点的函数:

function generate(nl::NoisyLinearDataGenerator, N)
    x = rand(nl.x_dist, N)'
    e = rand(nl.noise_dist, N)
    return x, x*nl.weights + e
end
Run Code Online (Sandbox Code Playgroud)

这似乎有效,但类型不稳定,因为

nl = NoisyLinearDataGenerator(MvNormal(5, 1.0), Normal(), ones(5))

@code_warntype generate(nl,1)
Run Code Online (Sandbox Code Playgroud)

产量

Variables
  #self#::Core.Compiler.Const(generate, false)
  nl::NoisyLinearDataGenerator
  N::Int64
  x::Any
  e::Any

Body::Tuple{Any,Any}
1 ? %1  = Base.getproperty(nl, :x_dist)::Distribution{Multivariate,Continuous}
?   %2  = Main.rand(%1, N)::Any
?         (x = Base.adjoint(%2))
?   %4  = Base.getproperty(nl, :noise_dist)::Distribution{Univariate,Continuous}
?         (e = Main.rand(%4, N))
?   %6  = x::Any
?   %7  = x::Any
?   %8  = Base.getproperty(nl, :weights)::Array{Float64,1}
?   %9  = (%7 * %8)::Any
?   %10 = (%9 + e)::Any
?   %11 = Core.tuple(%6, %10)::Tuple{Any,Any}
???       return %11

Run Code Online (Sandbox Code Playgroud)

我不确定为什么会这样,因为我希望通过使用ContinuousMultivariateDistribution和指定采样数据的类型ContinuousUnivariateDistribution

是什么导致类型不稳定,类型稳定的实现应该是什么样的?

Jak*_*sen 6

问题在于ContinuousMultivariateDistributionContinuousUnivariateDistribution是抽象类型。虽然您的统计知识告诉您他们可能应该 return Float64,但在语言级别上并不能保证有人不会实现,例如, aContinuousUnivariateDistribution返回某个其他对象。因此编译器无法知道all ContinuousUnivariateDistribution产生任何特定类型。

例如,我可能会写:

struct BadDistribution <: ContinuousUnivariateDistribution end
Base.rand(::BadDistribution, ::Integer) = nothing
Run Code Online (Sandbox Code Playgroud)

现在,你可以制作一个NoisyLinearDataGenerator包含 a BadDistributionas x_dist。那么输出类型是什么?

换句话说,输出generate根本不能只从它的输入类型的预测。

为了解决这个问题,你需要为你的新类型指定特定的分布,或者让你的新类型参数化。在 Julia 中,每当我们有一个不能指定为具体类型的类型的字段时,我们通常将它作为类型参数。因此,一种可能的解决方案是:

using Distributions

struct NoisyLinearDataGenerator{X,N}
    x_dist::X
    noise_dist::N
    weights::Vector{Float64}

    function NoisyLinearDataGenerator{X,N}(x::X, n::N, w::Vector{Float64}) where {
                                    X <: ContinuousMultivariateDistribution,
                                    N <: ContinuousUnivariateDistribution}
        return new{X,N}(x,n,w)
    end
end

function NoisyLinearDataGenerator(x::X, n::N, w::Vector{Float64}) where {
                                X <: ContinuousMultivariateDistribution,
                                N <: ContinuousUnivariateDistribution}
    return NoisyLinearDataGenerator{X,N}(x,n,w)
end

function generate(nl::NoisyLinearDataGenerator, N)
    x = rand(nl.x_dist, N)'
    e = rand(nl.noise_dist, N)
    return x, x*nl.weights + e
end
nl = NoisyLinearDataGenerator(MvNormal(5, 1.0), Normal(), ones(5))
Run Code Online (Sandbox Code Playgroud)

在这里, 的类型nlNoisyLinearDataGenerator{MvNormal{Float64,PDMats.ScalMat{Float64},FillArrays.Zeros{Float64,1,Tuple{Base.OneTo{Int64}}}},Normal{Float64}}(是的,我知道,很难读),但它的类型包含编译器完全预测generate.