小编Isa*_*ugt的帖子

Flux 的自定义梯度而不是使用 Zygote AD

我有一个机器学习模型,其中模型参数的梯度是解析的,不需要自动微分。然而,我仍然希望能够利用 Flux 中的不同优化器,而不必依赖 Zygote 进行区分。这是我的代码的一些片段。

\n\n
W = rand(Nh, N)\nU = rand(N, Nh)\nb = rand(N)\nc = rand(Nh)\n\n\xce\xb8 = Flux.Params([b, c, U, W])\n\nopt = ADAM(0.01)\n
Run Code Online (Sandbox Code Playgroud)\n\n

然后我有一个函数可以计算模型参数的解析梯度,\xce\xb8

\n\n
function gradients(x) # x = one input data point or a batch of input data points\n    # stuff to calculate gradients of each parameter\n    # returns gradients of each parameter\n
Run Code Online (Sandbox Code Playgroud)\n\n

然后我希望能够做如下的事情。

\n\n
grads = gradients(x)\nupdate!(opt, \xce\xb8, grads)\n
Run Code Online (Sandbox Code Playgroud)\n\n

我的问题是:我的函数需要返回什么形式/类型gradient(x)才能执行此操作update!(opt, \xce\xb8, grads),以及如何执行此操作?

\n

flux julia

5
推荐指数
1
解决办法
836
查看次数

标签 统计

flux ×1

julia ×1