使用 Zygote 计算 Julia 中包含 NN wrt 参数的损失函数的 Hessian

Mis*_*iss 5 julia flux.jl

如何计算由神经网络和神经网络参数组成的损失函数的粗麻布矩阵?

\n

例如,考虑下面的损失函数

\n
using Flux: Chain, Dense, \xcf\x83, crossentropy, params\nusing Zygote\nmodel = Chain(\n    x -> reshape(x, :, size(x, 4)),\n    Dense(2, 5),\n    Dense(5, 1),\n    x -> \xcf\x83.(x)\n)\nn_data = 5\ninput = randn(2, 1, 1, n_data)\ntarget = randn(1, n_data)\nloss = model -> crossentropy(model(input), target)\n
Run Code Online (Sandbox Code Playgroud)\n

我可以通过两种方式获得梯度wrt参数\xe2\x80\xa6

\n
Zygote.gradient(model -> loss(model), model)\n
Run Code Online (Sandbox Code Playgroud)\n

或者

\n
grad = Zygote.gradient(() -> loss(model), params(model))\ngrad[params(model)[1]]\n
Run Code Online (Sandbox Code Playgroud)\n

但是,我可以\xe2\x80\x99t找到一种方法来获取其参数的粗麻布。(我想做类似的事情Zygote.hessian(model -> loss(model), model),但Zygote.hessian::Params作为输入)

\n

最近,主分支中添加了jacobian一个函数(问题#910),它被理解为输入::Params

\n

我一直在尝试组合gradientjacobian获得粗麻布(因为粗麻布是函数梯度的雅可比矩阵),但无济于事。\n我认为问题在于,这是model一个Chain包含通用函数的对象,例如reshape\xcf\x83.缺少参数,但我无法克服这一点。

\n
grad = model -> Zygote.gradient(model -> loss(model), model)\njacob = model -> Zygote.jacobian(grad, model)\njacob(model) ## does not work\n
Run Code Online (Sandbox Code Playgroud)\n

编辑:仅供参考,我之前在 pytorch 中创建过这个

\n

pat*_*alt 0

不确定这是否对您的特定用例有帮助,但您可以使用 Hessian 的近似值,例如经验费希尔 (EF)。受此PyTorch 实现的启发,我使用这种方法来实现 Flux 模型的拉普拉斯近似(请参阅此处)。下面我已将该方法应用于您的示例。

\n
using Flux: Chain, Dense, \xcf\x83, crossentropy, params, DataLoader\nusing Zygote\nusing Random\n\nRandom.seed!(2022)\nmodel = Chain(\n    x -> reshape(x, :, size(x, 4)),\n    Dense(2, 5),\n    Dense(5, 1),\n    x -> \xcf\x83.(x)\n)\nn_data = 5\ninput = randn(2, 1, 1, n_data)\ntarget = randn(1, n_data)\nloss(x, y) = crossentropy(model(x), y)\n\nn_params = length(reduce(vcat, [vec(\xce\xb8) for \xce\xb8 \xe2\x88\x88 params(model)]))\n = zeros(n_params,n_params)\ndata = DataLoader((input, target))\n\nfor d in data\n  x, y = d\n   = gradient(() -> loss(x,y),params(model))  \n   = reduce(vcat,[vec([\xce\xb8]) for \xce\xb8 \xe2\x88\x88 params(model)])\n   +=  * \' # empirical fisher\nend\n
Run Code Online (Sandbox Code Playgroud)\n

如果有一种方法可以直接使用 Zygote autodiff(并且更有效),我也有兴趣看到这一点。将 EF 用于完整的 Hessian 矩阵仍然会在参数数量上呈二次方缩放,但如这篇 NeurIPS 2021论文中所示,您可以使用(博客)对角分解进一步近似 Hessian 矩阵。该论文还表明,在贝叶斯深度学习的背景下,仅概率性地处理最后一层通常会产生良好的结果,但再次不确定是否与您的情况相关。

\n