我有一个机器学习模型,其中模型参数的梯度是解析的,不需要自动微分。然而,我仍然希望能够利用 Flux 中的不同优化器,而不必依赖 Zygote 进行区分。这是我的代码的一些片段。
\n\nW = rand(Nh, N)\nU = rand(N, Nh)\nb = rand(N)\nc = rand(Nh)\n\n\xce\xb8 = Flux.Params([b, c, U, W])\n\nopt = ADAM(0.01)\nRun Code Online (Sandbox Code Playgroud)\n\n然后我有一个函数可以计算模型参数的解析梯度,\xce\xb8。
function gradients(x) # x = one input data point or a batch of input data points\n # stuff to calculate gradients of each parameter\n # returns gradients of each parameter\nRun Code Online (Sandbox Code Playgroud)\n\n然后我希望能够做如下的事情。
\n\ngrads = gradients(x)\nupdate!(opt, \xce\xb8, grads)\nRun Code Online (Sandbox Code Playgroud)\n\n我的问题是:我的函数需要返回什么形式/类型gradient(x)才能执行此操作update!(opt, \xce\xb8, grads),以及如何执行此操作?